Skip to content

Instantly share code, notes, and snippets.

@jeremyorme
Created November 18, 2020 06:23
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jeremyorme/aaaf03535546bdbed2193c72ea3c03ca to your computer and use it in GitHub Desktop.
Save jeremyorme/aaaf03535546bdbed2193c72ea3c03ca to your computer and use it in GitHub Desktop.
Set-based Monte Carlo tree search problem solver
import random
# You have a fox, a chicken and a sack of grain. You must cross a river taking at most one of them with you at a time. If you leave the fox alone with the chicken, the fox will eat the chicken; if you leave the chicken alone with the grain, the chicken will eat the grain. How can you get all three across safely?
# The two locations, and everything that has to end up on the far one.
_BANKS = ('near bank', 'far bank')
_EVERYONE = ('you', 'fox', 'chicken', 'sack of grain')


def _crossing_rules(*travellers):
    """Build the two rules that ferry *travellers* across, one per direction.

    The near-to-far rule comes first; each rule only fires when every
    traveller is on the bank it departs from.
    """
    rules = []
    for origin, destination in (_BANKS, _BANKS[::-1]):
        rules.append({
            'if': origin,
            'contains': list(travellers),
            'then': [
                {'modify': origin, 'remove': list(travellers)},
                {'modify': destination, 'add': list(travellers)},
            ],
        })
    return rules


def _eating_rules(eater, eaten):
    """Build the per-bank rules that remove *eaten* when left with *eater* without 'you'."""
    return [
        {
            'if': bank,
            'contains': [eater, eaten],
            'not_contains': ['you'],
            'then': [{'modify': bank, 'remove': [eaten]}],
        }
        for bank in _BANKS
    ]


# Declarative description of the river-crossing puzzle:
#   start_state / win_state -- contents of each bank, as sets
#   actions                 -- moves the player may choose between
#   consequences            -- rules applied after every successful move
problem_spec = {
    'start_state': {
        'near bank': set(_EVERYONE),
        'far bank': set(),
    },
    'win_state': {
        'far bank': set(_EVERYONE),
        'near bank': set(),
    },
    'actions': {
        'cross with fox': _crossing_rules('you', 'fox'),
        'cross with chicken': _crossing_rules('you', 'chicken'),
        'cross with sack of grain': _crossing_rules('you', 'sack of grain'),
        'cross alone': _crossing_rules('you'),
    },
    'consequences': {
        'fox eats chicken': _eating_rules('fox', 'chicken'),
        'chicken eats grain': _eating_rules('chicken', 'sack of grain'),
    },
}
class explored_move:
    """One node in the tree of moves explored so far.

    A node counts how often its move has been played and how many wins have
    been propagated through it. Child nodes are created lazily, keyed by
    action index, the first time a follow-up move is played.
    """

    # Statistics assumed for a move that has not been tried yet: one win in
    # two plays, so unexplored moves start with a non-zero selection weight.
    initial_wins = 1
    initial_plays = 2

    def __init__(self, num_actions, prev_move=None):
        """Create a node.

        Args:
            num_actions: number of actions that can follow this move.
            prev_move: the node this move was played from, or None for the root.
        """
        self.num_actions = num_actions
        self.prev_move = prev_move
        self.next_moves = {}
        self.wins = explored_move.initial_wins
        self.plays = explored_move.initial_plays

    def play(self):
        """Record one more play of this move."""
        self.plays += 1

    def illegal(self):
        """Mark this move as illegal; a zero win count gives it zero weight."""
        self.wins = 0

    def win(self):
        """Record a win on this move and on every move leading up to it."""
        # Walk the parent chain iteratively: a recursive walk needs one stack
        # frame per ancestor and can exceed the interpreter's recursion limit
        # on long paths.
        node = self
        while node is not None:
            node.wins += 1
            node = node.prev_move

    def num_wins(self, i):
        """Return the win count of next move i (the initial value if unexplored)."""
        return self.next_moves[i].wins if i in self.next_moves else explored_move.initial_wins

    def num_plays(self, i):
        """Return the play count of next move i (the initial value if unexplored)."""
        return self.next_moves[i].plays if i in self.next_moves else explored_move.initial_plays

    def play_next(self, i):
        """Play next move i, creating its node on first use, and return that node."""
        if i not in self.next_moves:
            self.next_moves[i] = explored_move(self.num_actions, self)
        next_move = self.next_moves[i]
        next_move.play()
        return next_move

    def _common_denominator(self):
        """Return the product of the play counts of all next moves with wins > 0.

        Multiplying a win ratio (wins / plays) by this value yields an exact
        integer, which lets ratios be compared and summed without floats.
        """
        denominator = 1
        for m in range(self.num_actions):
            if self.num_wins(m) > 0:
                denominator *= self.num_plays(m)
        return denominator

    def choose_next_move(self):
        """Pick a next move at random, weighted by each move's win ratio.

        Returns:
            The index of the chosen move, or -1 if every move has zero wins
            (no valid move).
        """
        d = self._common_denominator()
        # Moves with zero wins contribute zero weight whatever d // plays is.
        weights = [
            self.num_wins(m) * (d // self.num_plays(m))
            for m in range(self.num_actions)
        ]
        total = sum(weights)
        if total == 0:
            return -1
        # Select a random point in the weighted space and find the move whose
        # interval contains it.
        point = random.randrange(0, total)
        running = 0
        for m, weight in enumerate(weights):
            running += weight
            if running > point:
                return m
        # Should never get here unless the above maths went wrong!
        raise RuntimeError('Unreachable: ' + str(running) + ' > ' + str(total))

    def best_move(self):
        """Return the explored next move with the highest win ratio.

        Returns:
            A tuple (index, node), or (-1, None) when no explored next move
            has a positive win count. On ties the move explored first wins.
        """
        d = self._common_denominator()
        best_ratio = 0
        best_m = (-1, None)
        for m, node in self.next_moves.items():
            ratio = node.wins * (d // node.plays)
            if ratio > best_ratio:
                best_ratio = ratio
                best_m = (m, node)
        return best_m
def apply_action(state, action):
    """Apply the first rule of *action* whose condition matches *state*.

    Each rule is a dict naming a location under 'if', optional 'contains' and
    'not_contains' item lists that the location must (not) hold, and a 'then'
    list of modifications. A modification names a location under 'modify' and
    lists items to 'remove' and/or 'add'.

    Args:
        state: dict mapping location name to a set of items; mutated in place.
        action: list of rules, tried in order.

    Returns:
        True if a rule matched and was applied, False if none matched.

    Raises:
        KeyError: if a rule removes an item that is not at the location.
    """
    for rule in action:
        required = rule.get('contains', ())
        forbidden = rule.get('not_contains', ())
        # Only look the location up when there is something to check.
        if required or forbidden:
            contents = state[rule['if']]
            if any(item not in contents for item in required):
                continue
            if any(item in contents for item in forbidden):
                continue
        # Apply every modification. 'remove' and 'add' are handled
        # independently so a modification carrying both is applied in full.
        for change in rule['then']:
            for item in change.get('remove', ()):
                state[change['modify']].remove(item)
            for item in change.get('add', ()):
                state[change['modify']].add(item)
        # Stop at the first matching rule; later rules must not see its result.
        return True
    return False
def copy_state(state):
    """Return a copy of *state* whose per-location sets are independent.

    Args:
        state: dict mapping location name to an iterable of items.

    Returns:
        A new dict with the same keys, each mapped to a new set of the items.
    """
    return {name: set(contents) for name, contents in state.items()}
def equal_state(left, right):
    """Return True when both states have the same locations with equal contents.

    Args:
        left: dict mapping location name to its contents.
        right: dict mapping location name to its contents.

    Returns:
        False if either state has a location the other lacks, or if any
        shared location's contents differ; True otherwise.
    """
    if left.keys() != right.keys():
        return False
    return not any(left[name] != right[name] for name in left)
def try_play_move(actions, consequences, cur_move, state):
    """Select one move from *cur_move* and try to play it against *state*.

    Args:
        actions: list of actions, indexed by move number.
        consequences: list of rule lists applied after a successful action.
        cur_move: the node to choose the next move from.
        state: the current state; it is copied, never mutated.

    Returns:
        (None, None) if no move can be selected;
        (cur_move, state) unchanged if the selected action does not apply to
        the state (the selected move is then marked illegal);
        otherwise (next_move, next_state) after the action and consequences.
    """
    # try to select a move
    i = cur_move.choose_next_move()
    # if no valid move return Nones
    if i == -1:
        return None, None
    # make the move on a copy so the caller's state is left untouched
    next_move = cur_move.play_next(i)
    next_state = copy_state(state)
    if not apply_action(next_state, actions[i]):
        next_move.illegal()
        return cur_move, state
    # apply consequences
    for c in consequences:
        apply_action(next_state, c)
    return next_move, next_state
# play with rules to reach output from input
def play(spec, paths, max_depth=1000):
    """Explore random move sequences and return the most successful route.

    Args:
        spec: problem description with 'actions', 'consequences',
            'start_state' and 'win_state' entries.
        paths: number of random paths to explore from the start state.
        max_depth: maximum number of move attempts per path; attempts that
            turn out to be illegal count towards this limit.

    Returns:
        A list of action indices (positions in spec['actions'], in iteration
        order) following the best explored move at each step. The list is
        empty if no explored path reached the win state.
    """
    actions = list(spec['actions'].values())
    consequences = list(spec['consequences'].values())
    start_state = spec['start_state']
    win_state = spec['win_state']
    # start with root move
    root = explored_move(len(actions))
    # explore random paths
    for _ in range(paths):
        # reset state for new path; try_play_move copies before mutating, so
        # sharing the start state between paths is safe
        state = start_state
        cur_move = root
        # make depth number of moves
        for _ in range(max_depth):
            # play a random move
            next_move, next_state = try_play_move(actions, consequences, cur_move, state)
            # if no legal move then stop this path
            if next_move is None:
                break
            # illegal move: the same node is handed back, so try again
            if next_move is cur_move:
                continue
            # check if matches the win state
            if equal_state(next_state, win_state):
                next_move.win()
                break
            # move on to next move
            cur_move = next_move
            state = next_state
    # replay move sequence with most wins, stopping at the first node that
    # never had a win propagated through it
    route = []
    i, node = root.best_move()
    while node is not None and node.wins > explored_move.initial_wins:
        route.append(i)
        i, node = node.best_move()
    # return the route
    return route
# go!
# Explore 500 random paths through the puzzle and print the resulting route:
# a list of action indices, i.e. positions in problem_spec['actions'] in
# insertion order (0 = 'cross with fox', 1 = 'cross with chicken',
# 2 = 'cross with sack of grain', 3 = 'cross alone').
route = play(problem_spec, 500)
print(route)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment