Created November 18, 2020 06:23
-
-
Save jeremyorme/aaaf03535546bdbed2193c72ea3c03ca to your computer and use it in GitHub Desktop.
Set-based Monte Carlo tree search problem solver
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random
from fractions import Fraction
# You have a fox, a chicken and a sack of grain. You must cross a river with only one of them at a time. If you leave the fox with the chicken he will eat it; if you leave the chicken with the grain he will eat it. How can you get all three across safely?


def _crossing_rules(passengers):
    """Build the pair of rules that ferry `passengers` across the river.

    The near-to-far rule comes first and the far-to-near rule second. Each rule
    requires all passengers on the departure bank, removes them from it and adds
    them to the other bank.
    """
    rules = []
    for origin, destination in (('near bank', 'far bank'), ('far bank', 'near bank')):
        rules.append({
            'if': origin,
            'contains': list(passengers),
            'then': [
                {'modify': origin, 'remove': list(passengers)},
                {'modify': destination, 'add': list(passengers)},
            ],
        })
    return rules


def _eating_rules(predator, prey):
    """Build the pair of rules (near bank, then far bank) where `predator` eats `prey`.

    A rule fires when both are on the same bank and 'you' are not there; it
    removes the prey from that bank.
    """
    return [
        {
            'if': bank,
            'contains': [predator, prey],
            'not_contains': ['you'],
            'then': [{'modify': bank, 'remove': [prey]}],
        }
        for bank in ('near bank', 'far bank')
    ]


problem_spec = {
    'start_state': {
        'near bank': {'you', 'fox', 'chicken', 'sack of grain'},
        'far bank': set(),
    },
    'win_state': {
        'far bank': {'you', 'fox', 'chicken', 'sack of grain'},
        'near bank': set(),
    },
    # insertion order matters: it fixes the action index used by the solver
    'actions': {
        'cross with fox': _crossing_rules(['you', 'fox']),
        'cross with chicken': _crossing_rules(['you', 'chicken']),
        'cross with sack of grain': _crossing_rules(['you', 'sack of grain']),
        'cross alone': _crossing_rules(['you']),
    },
    'consequences': {
        'fox eats chicken': _eating_rules('fox', 'chicken'),
        'chicken eats grain': _eating_rules('chicken', 'sack of grain'),
    },
}
class explored_move:
    """Node of the Monte Carlo search tree over move sequences.

    Each node keeps a win count and a play count. Children are keyed by action
    index and created lazily by play_next. New nodes, and moves not explored
    yet, start at initial_wins / initial_plays. This non-zero prior keeps
    untried moves selectable. Move selection uses exact integer arithmetic
    rather than floats.
    """

    # prior counts for every new node and for moves that have not been explored
    initial_wins = 1
    initial_plays = 2

    def __init__(self, num_actions, prev_move=None):
        """Create a node.

        num_actions: number of candidate actions (child indices 0..num_actions-1).
        prev_move: parent node, or None for the root.
        """
        self.num_actions = num_actions
        self.prev_move = prev_move
        self.next_moves = {}  # action index -> child explored_move
        self.wins = explored_move.initial_wins
        self.plays = explored_move.initial_plays

    def play(self):
        """Count one more play of this move."""
        self.plays += 1

    def illegal(self):
        """Mark this move as illegal (don't try again): zero wins means zero selection weight."""
        self.wins = 0

    def win(self):
        """Count a win on this node and on every ancestor up to the root.

        The parent chain is walked iteratively. A recursive walk needs one stack
        frame per level, so it overflows Python's default recursion limit (1000)
        for paths around a thousand moves deep.
        """
        node = self
        while node is not None:
            node.wins += 1
            node = node.prev_move

    def num_wins(self, i):
        """Return the win count of next move i, or the prior if it is unexplored."""
        return self.next_moves[i].wins if i in self.next_moves else explored_move.initial_wins

    def num_plays(self, i):
        """Return the play count of next move i, or the prior if it is unexplored."""
        return self.next_moves[i].plays if i in self.next_moves else explored_move.initial_plays

    def play_next(self, i):
        """Return the child for action i, creating it if needed, after counting a play of it."""
        if i not in self.next_moves:
            self.next_moves[i] = explored_move(self.num_actions, self)
        next_move = self.next_moves[i]
        next_move.play()
        return next_move

    def _common_denominator(self):
        """Return the product of the play counts of every next move with non-zero wins.

        Each such play count divides the result, so wins * (d // plays) is that
        move's win ratio scaled exactly to an integer.
        """
        d = 1
        for m in range(self.num_actions):
            if self.num_wins(m) > 0:
                d *= self.num_plays(m)
        return d

    def choose_next_move(self):
        """Pick the index of a next move at random, weighted by its win ratio.

        Returns -1 when every move has zero wins (no valid move).
        """
        d = self._common_denominator()
        # win ratio * common denominator for every move (0 for illegal moves)
        weights = [self.num_wins(m) * (d // self.num_plays(m)) for m in range(self.num_actions)]
        max_i = sum(weights)
        if max_i == 0:
            return -1
        # select a random point in the weighted move probability space
        j = random.randrange(0, max_i)
        # accumulate the move intervals to find the move whose interval holds j
        i = 0
        for m, weight in enumerate(weights):
            i += weight
            if i > j:
                return m
        # should never get here unless the above maths went wrong!
        raise Exception('Unreachable: ' + str(i) + ' > ' + str(max_i))

    def best_move(self):
        """Return (action index, child) for the explored child with the best win ratio.

        Returns (-1, None) when no explored child has a positive ratio. Ties go
        to the child that was created first.
        """
        d = self._common_denominator()
        best_ratio = 0
        best_m = (-1, None)
        for m, node in self.next_moves.items():
            ratio = node.wins * (d // node.plays)
            if ratio > best_ratio:
                best_ratio = ratio
                best_m = (m, node)
        return best_m
def apply_action(state, action):
    """Apply the first rule of `action` whose conditions hold in `state`.

    state: dict mapping a location name to the set of items there; mutated in place.
    action: list of rule dicts, tried in order. A rule matches when every item in
        its optional 'contains' list is in state[rule['if']] and no item in its
        optional 'not_contains' list is. A rule with neither condition always
        matches.

    The matching rule's 'then' modifications are applied in order. Each one may
    'remove' items from state[modification['modify']] (KeyError if an item is
    absent) and/or 'add' items to it. Removals happen before additions, so a
    single modification can replace items.

    Returns True if a rule was applied, False if no rule matched.
    """
    for rule in action:
        contains = rule.get('contains', ())
        not_contains = rule.get('not_contains', ())
        # only look up the location when there is something to check there
        if contains or not_contains:
            location = state[rule['if']]
            if not all(item in location for item in contains):
                continue
            if any(item in location for item in not_contains):
                continue
        for change in rule['then']:
            for item in change.get('remove', ()):
                state[change['modify']].remove(item)
            for item in change.get('add', ()):
                state[change['modify']].add(item)
        return True
    return False
def copy_state(state):
    """Return a copy of `state` whose per-location sets are independent of the original's."""
    return {location: set(items) for location, items in state.items()}
def equal_state(left, right):
    """Return True when both states name the same locations and each location's contents are equal."""
    if left.keys() != right.keys():
        return False
    return not any(left[location] != right[location] for location in left)
def try_play_move(actions, consequences, cur_move, state):
    """Choose one move from `cur_move` and try it against `state`.

    Returns one of:
      (None, None) when no candidate move has any selection weight left;
      (cur_move, state), unchanged, when the chosen action's rules did not
        apply (the chosen child node is marked illegal);
      (next_move, next_state) otherwise. next_state is a fresh copy of state
        with the action and then every consequence applied.
    """
    choice = cur_move.choose_next_move()
    if choice == -1:
        return None, None
    candidate = cur_move.play_next(choice)
    candidate_state = copy_state(state)
    if apply_action(candidate_state, actions[choice]):
        # consequences (e.g. something getting eaten) are applied after every legal move
        for consequence in consequences:
            apply_action(candidate_state, consequence)
        return candidate, candidate_state
    candidate.illegal()
    return cur_move, state
# play with rules to reach output from input
def play(spec, paths, max_depth=1000):
    """Search for a sequence of actions leading from spec's start state to its win state.

    spec: problem description with 'start_state', 'win_state', 'actions' and
        'consequences' entries, laid out like problem_spec.
    paths: number of random paths to explore.
    max_depth: maximum number of move attempts per path; illegal attempts count too.

    Returns the list of action indices (positions in spec['actions'] order) along
    the best-ratio chain of moves that have led to a win. The list is empty if
    no explored path reached the win state.
    """
    actions = list(spec['actions'].values())
    consequences = list(spec['consequences'].values())
    start_state = spec['start_state']
    win_state = spec['win_state']
    # start with root move
    root = explored_move(len(actions))
    # explore random paths
    for _path in range(paths):
        # try_play_move copies the state before changing it, so start_state is never mutated
        state = start_state
        cur_move = root
        for _step in range(max_depth):
            next_move, next_state = try_play_move(actions, consequences, cur_move, state)
            if next_move is None:
                break  # no move from here has any weight left
            if next_move is cur_move:
                continue  # illegal move: stay put and try another
            if equal_state(next_state, win_state):
                next_move.win()
                break
            cur_move = next_move
            state = next_state
    # Replay the best path through moves that actually led to a win (wins above
    # the initial prior). Ranking *all* children by ratio could pick a
    # never-won child whose prior ratio beats a winning one and cut the route short.
    route = []
    node = root
    while True:
        won = [(m, child) for m, child in node.next_moves.items()
               if child.wins > explored_move.initial_wins]
        if not won:
            return route
        # exact ratios; max() keeps the first of equal ratios, i.e. the earliest-created child
        m, node = max(won, key=lambda item: Fraction(item[1].wins, item[1].plays))
        route.append(m)
# go! explore 500 random paths, then print the best route found
# (indices into problem_spec['actions'])
route = play(problem_spec, paths=500)
print(route)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment