Skip to content

Instantly share code, notes, and snippets.

@DomNomNom
Created June 8, 2013 17:06
Show Gist options
  • Save DomNomNom/5735893 to your computer and use it in GitHub Desktop.
Save DomNomNom/5735893 to your computer and use it in GitHub Desktop.
from collections import deque

import matplotlib.pyplot as plt
import networkx as nx
# Graph of explored states: astar() adds each reached state as a node and
# each move leading to it as an edge; the script draws it at the end.
G = nx.Graph()
# The three things that have to be ferried across, named by themselves.
wolf, goat, rose = 'wolf', 'goat', 'rose'
objects = [wolf, goat, rose]

# The positions a thing can occupy.
boat, shore_l, shore_r = 'boat', 'shore_l', 'shore_r'

# Adjacency table: for each position, the positions reachable in one move.
# Either shore only connects to the boat; the boat can unload on either side.
transitions = {}
transitions[shore_l] = [boat]
transitions[shore_r] = [boat]
transitions[boat] = [shore_r, shore_l]
# A state of the world.
# It has a linked history (prevstate) and records where each thing is.
class State(object):
    """One arrangement of the wolf, goat and rose.

    ``objects`` maps each thing's name to its position (shore_l, boat or
    shore_r). ``prevstate`` links to the state this one was derived from
    (None for a root state) and ``moves`` counts the steps taken to get here.
    Equality and hashing look only at ``objects``, so the same arrangement
    reached through different histories counts as the same state.
    """

    def __init__(self, prevstate, **objects):
        assert len(objects) == 3
        self.prevstate = prevstate
        self.objects = objects  # note: should be treated as immutable.
        # how many actions did we take to get here
        self.moves = prevstate.moves + 1 if prevstate is not None else 0

    def __hash__(self):
        # Must agree with __eq__: only positions matter, not history.
        return hash(frozenset(self.objects.items()))

    def __eq__(self, other):
        # Comparing with a non-State (e.g. None) used to raise AttributeError;
        # returning NotImplemented lets Python fall back to its default.
        if not isinstance(other, State):
            return NotImplemented
        return self.objects == other.objects

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; harmless on Python 3.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __str__(self):
        """Render as '<left shore> | <boat> | <right shore>'."""
        def names_at(place):
            return ' '.join(obj for obj, pos in self.objects.items() if pos == place)
        return ' | '.join([names_at(shore_l), names_at(boat), names_at(shore_r)])

    def isvalid(self):
        """Return True if this arrangement breaks none of the rules.

        Rules: at most one thing on the boat, and the goat never shares a
        position with the rose or with the wolf.
        Note: the goal state (everything on shore_r) technically is not valid.
        """
        on_boat = sum(1 for pos in self.objects.values() if pos == boat)
        return (
            on_boat <= 1  # only 1 passenger on the boat
            # make sure nothing gets eaten
            and self.objects[goat] != self.objects[rose]
            and self.objects[goat] != self.objects[wolf]
        )

    def reachables(self):
        """Yield every state reachable from this one in a single move.

        For each thing and each destination adjacent to its position (per
        ``transitions``), yield one state per thing already at the destination
        (the two swap places) and one where the thing simply moves there.
        The yielded states are not filtered with isvalid().
        """
        for obj, pos in self.objects.items():
            for destination in transitions[pos]:
                occupants = [other for other, where in self.objects.items()
                             if where == destination]
                for swapwith in occupants + [None]:
                    yield self.swap(obj, pos, swapwith, destination)

    def swap(self, obj_a, pos_a, obj_b, pos_b):
        """Return a new child state in which obj_a is at pos_b.

        If obj_b is not None it is moved to pos_a, so the two trade places;
        if obj_b is None only obj_a moves. This state is left unchanged.
        """
        newobjects = dict(self.objects)
        newobjects[obj_a] = pos_b
        if obj_b is not None:
            newobjects[obj_b] = pos_a
        return State(self, **newobjects)
def astar(starts, goals, graph=None):
    """Explore the state space from ``starts`` and record it in a graph.

    Despite the name there is no heuristic: states are expanded in FIFO
    order, so this is a breadth-first search.

    Args:
        starts: iterable of initial states. It is copied, so the caller's
            list is no longer consumed or modified.
        goals: collection of goal states (tested with ``in``). Goal states
            are added to the graph and printed, but never expanded.
        graph: object with ``add_node``/``add_edge`` (e.g. a networkx graph)
            that receives every valid or goal state reached plus the edge
            leading to it. Defaults to the module-level ``G``.

    Returns:
        The first goal state encountered, or None if no goal is reachable.
        Because expansion is breadth-first, this goal is reached in the fewest
        moves; its ``prevstate`` chain gives the path. The search still runs
        to exhaustion so that the graph is complete.
    """
    if graph is None:
        graph = G
    best = None
    exploredstates = set()
    toexplore = deque(starts)  # O(1) popleft; does not alias the caller's list
    while toexplore:
        state = toexplore.popleft()
        if state in exploredstates:
            continue
        exploredstates.add(state)
        # Successors of an invalid state (e.g. an invalid start) are still
        # recorded in the graph, but never queued for expansion.
        expandable = state.isvalid()
        for newstate in state.reachables():
            is_goal = newstate in goals
            valid = newstate.isvalid()
            if not (is_goal or valid):
                continue
            graph.add_node(newstate)
            graph.add_edge(state, newstate)
            if is_goal:
                print(newstate)
                if best is None:
                    best = newstate
            elif valid and expandable:
                toexplore.append(newstate)  # TODO: insert ordered
    return best
# Starting arrangements to search from; alternatives are kept commented out.
start_states = [
    # State(None, wolf=boat, goat=shore_l, rose=shore_l),
    State(None, wolf=shore_l, goat=boat, rose=shore_l),
    # State(None, wolf=shore_l, goat=shore_l, rose=boat ),
]
# Target arrangement: everything on the right-hand shore.
goal_states = {
    State(None, wolf=shore_r, goat=shore_r, rose=shore_r),
}
best = astar(start_states, goal_states)
print(best)

# Draw the state graph G that astar() populated, using a spring layout.
nx.draw_spring(G)
plt.show()
print('done')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment