8puzzle.py
changeset 1 94fc07d1e2e1
     1.1 --- /dev/null	Thu Jan 01 00:00:00 1970 +0000
     1.2 +++ b/8puzzle.py	Thu May 26 23:26:48 2011 +0200
     1.3 @@ -0,0 +1,112 @@
     1.4 +"""
     1.5 +Name: A* on 8 puzzle
     1.6 +Author: Eugen Sawin <sawine@informatik.uni-freiburg.de>
     1.7 +"""
     1.8 +import heapq
     1.9 +
    1.10 +def main():    
    1.11 +    goal = (1, 2, 3, 
    1.12 +            8, 0, 4, 
    1.13 +            7, 6, 5)
    1.14 +    init = (2, 8, 3, 
    1.15 +            1, 6, 4, 
    1.16 +            7, 0, 5)
    1.17 +    print "heuristic: none"
    1.18 +    search(init, goal, Heuristic(goal))
    1.19 +    print "\nheuristic: misplaced tiles"
    1.20 +    search(init, goal, MisplacedTiles(goal))
    1.21 +    print "\nheuristic: manhattan distance"
    1.22 +    search(init, goal, ManhattanDist(goal))
    1.23 +
    1.24 +def search(init, goal, h): 
    1.25 +    frontier = []
    1.26 +    print init
    1.27 +    print goal
    1.28 +    heapq.heapify(frontier)
    1.29 +    node = Node(init, None, 0, 0)
    1.30 +    heapq.heappush(frontier, (0, node))
    1.31 +    explored = set()   
    1.32 +    while len(frontier):
    1.33 +        k, node = heapq.heappop(frontier)
    1.34 +        print k
    1.35 +        explored.add(node.state) 
    1.36 +        if node.state == goal:
    1.37 +            print "goal found in %i steps" % len(explored)
    1.38 +            break         
    1.39 +        for child in expand(node, h):  
    1.40 +            #print child
    1.41 +            #print
    1.42 +            heapq.heappush(frontier, (child.h_cost, child))
    1.43 +        
    1.44 +def expand(node, h):
    1.45 +    def swap(i1, i2):
    1.46 +        state = list(node.state)       
    1.47 +        a = state[i1]
    1.48 +        b = state[i2]
    1.49 +        state[i1] = b
    1.50 +        state[i2] = a
    1.51 +        return tuple(state)  
    1.52 +
    1.53 +    children = []
    1.54 +    cost = node.cost + 1
    1.55 +    for i, v in enumerate(node.state):
    1.56 +        if v == 0:
    1.57 +            if i > 0:
    1.58 +                s = swap(i, i-1)
    1.59 +                children.append(Node(s, node, cost,  h(s) + 1))
    1.60 +            if i < 8:
    1.61 +                s = swap(i, i+1)               
    1.62 +                children.append(Node(s, node, cost,  h(s) + 1)) 
    1.63 +            if i > 2: 
    1.64 +                s = swap(i, i-3)              
    1.65 +                children.append(Node(s, node, cost,  h(s) + 1))  
    1.66 +            if i < 6:
    1.67 +                s = swap(i, i+3)
    1.68 +                children.append(Node(s, node, cost,  h(s) + 1))             
    1.69 +    return children      
    1.70 +
    1.71 +class Node(object):
    1.72 +    def __init__(self, state, parent, cost, h_cost):
    1.73 +        self.state = state
    1.74 +        self.parent = parent
    1.75 +        self.cost = cost 
    1.76 +        self.h_cost = h_cost
    1.77 +    def __less__(self, node):
    1.78 +        return self.h_cost < node.h_cost
    1.79 +    def __str__(self):
    1.80 +        return "\n".join((str(self.state[:3]), 
    1.81 +                          str(self.state[3:6]), 
    1.82 +                          str(self.state[6:])))
    1.83 +
    1.84 +class Heuristic(object):
    1.85 +    def __init__(self, goal):
    1.86 +        self.goal = goal
    1.87 +    def __call__(self, state):
    1.88 +        return 0
    1.89 +
    1.90 +class MisplacedTiles(Heuristic):
    1.91 +   def __call__(self, state):       
    1.92 +       h = 0
    1.93 +       for i, s in enumerate(state):
    1.94 +           if self.goal[i] != s:
    1.95 +               h += 1       
    1.96 +       return h
    1.97 +
    1.98 +class ManhattanDist(Heuristic):
    1.99 +   def __call__(self, state):       
   1.100 +       h = 0
   1.101 +       for i, s in enumerate(state):
   1.102 +           for i2, s2 in enumerate(self.goal):
   1.103 +               if s == s2:
   1.104 +                   d = abs(i-i2)
   1.105 +                   h += d%3 + int(d/3)      
   1.106 +       return h
   1.107 +    
   1.108 +from argparse import ArgumentParser
   1.109 +
   1.110 +def parse_arguments():
   1.111 +    parser = ArgumentParser(description="8 puzzle")   
   1.112 +    return parser.parse_args()
   1.113 +
   1.114 +if __name__ == "__main__":
   1.115 +    main()