"""
Introduction to Monte Carlo Tree Search
http://jeffbradberry.com/posts/2015/09/intro-to-monte-carlo-tree-search/
"""
from copy import deepcopy
import datetime
from math import log, sqrt
from random import choice
# Tic-tac-toe board
class Board():
    # Helper for dealing with board state
    def _flatten(self, state):
        return sum(state, [])
    # Helper for dealing with board state
    def _unflatten(self, unrolled_state):
        state = [[0, 0, 0], [0, 0, 0], [0, 0, 0]]
        for y in range(3):
            for x in range(3):
                state[y][x] = unrolled_state[3 * y + x]
        return state
    # Returns the starting state of the game
    def reset(self):
        return [[0, 0, 0], [0, 0, 0], [0, 0, 0]]
    # Finds the current player's number
    def current_player(self, state):
        unrolled_state = self._flatten(state)
        p1_count = sum(e == 1 for e in unrolled_state)
        p2_count = sum(e == 2 for e in unrolled_state)
        return 1 if p1_count == p2_count else 2
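    # Player 1 always moves first, so it is player 1's turn whenever both players have placed an equal number of marks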
    # Makes a step in the environment
    def step(self, state, action):
        current = self.current_player(state)
        new_state = self._flatten(state)  # Performs copy of state
        new_state[action] = current
        return self._unflatten(new_state)
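    # e.g. step(reset(), 4) places player 1's mark in the centre cell (actions index cells 0-8 in row-major order)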
    # Returns the list of legal moves for the current player
    def legal_actions(self, state):
        return [i for i, e in enumerate(self._flatten(state)) if e == 0]
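    # e.g. legal_actions returns [0, 1, 2, 3, 4, 5, 6, 7, 8] for the empty starting board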
    # Returns winning player's number (1/2), 0 if ongoing, or -1 for draw
    def winner(self, state):
        # Check winning states (each branch covers the lines passing through one cell)
        if (state[0][0] == 1 or state[0][0] == 2) and \
           ((state[0][0] == state[0][1] and state[0][1] == state[0][2]) or
            (state[0][0] == state[1][1] and state[1][1] == state[2][2]) or
            (state[0][0] == state[1][0] and state[1][0] == state[2][0])):
            return state[0][0]  # Top row, main diagonal or left column
        elif (state[0][1] == 1 or state[0][1] == 2) and \
             state[0][1] == state[1][1] and state[1][1] == state[2][1]:
            return state[0][1]  # Middle column
        elif (state[0][2] == 1 or state[0][2] == 2) and \
             ((state[0][2] == state[1][1] and state[1][1] == state[2][0]) or
              (state[0][2] == state[1][2] and state[1][2] == state[2][2])):
            return state[0][2]  # Anti-diagonal or right column
        elif (state[1][0] == 1 or state[1][0] == 2) and \
             state[1][0] == state[1][1] and state[1][1] == state[1][2]:
            return state[1][0]  # Middle row
        elif (state[2][0] == 1 or state[2][0] == 2) and \
             state[2][0] == state[2][1] and state[2][1] == state[2][2]:
            return state[2][0]  # Bottom row
        elif any(s == 0 for s in self._flatten(state)):
            return 0  # Empty cells remain, so the game is still in progress
        else:
            return -1  # Assume draws only happen at end of game
    # Returns a unique hash per unique state
    def hash(self, state):
        return ''.join(str(e) for e in self._flatten(state))
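    # e.g. hash([[1, 0, 0], [0, 2, 0], [0, 0, 0]]) == '100020000'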
    # Converts state element encoding for pretty printing
    def _print_element(self, element):
        if element == 1:
            return 'X'
        elif element == 2:
            return 'O'
        else:
            return ' '
    # Pretty prints a state
    def pretty_print(self, state):
        pretty_state = map(self._print_element, self._flatten(state))
        print('Board:\n-----\n|%s%s%s|\n|%s%s%s|\n|%s%s%s|\n-----' % tuple(pretty_state))
# MCTS planner
class MCTS():
    # Initializes the game history (states only) and the statistics tables
    def __init__(self, board, **kwargs):
        self.board = board
        self.history = []
        self.c = sqrt(2)  # Exploration parameter
        self.wins = {}
        self.plays = {}
        self.calculation_time = datetime.timedelta(seconds=kwargs.get('time', 10))  # Max amount of time per move calculation
        self.max_moves = kwargs.get('max_moves', 100)  # Max number of moves per rollout
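    # Note: self.plays and self.wins map (player, state hash) -> visit/win counts; a state is part of
    # the search tree once it has an entry in self.plays, and UCT is only applied from states where
    # every child already has an entry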
    # Appends a game state to the history
    def update(self, state):
        self.history.append(state)
    # Calculate and return the best move from the current game state
    def get_action(self):
        self.max_depth = 0
        state = self.history[-1]
        player = self.board.current_player(state)
        legal = self.board.legal_actions(state)
        # Stop early if no choice to be made
        if len(legal) == 0:
            return
        elif len(legal) == 1:
            return legal[0]
        # Run simulations repeatedly until set time elapsed
        games = 0
        begin = datetime.datetime.utcnow()
        while datetime.datetime.utcnow() - begin < self.calculation_time:
            self.run_simulation()
            games += 1
        # Store state-action pairs (that are hashable)
        states_actions = [(self.board.hash(self.board.step(state, a)), a) for a in legal]
        # Display the number of calls to run_simulation and the time elapsed
        print('Current Player:', player, '| Simulated Games:', games, '| Search Time:', datetime.datetime.utcnow() - begin)
        # Pick the action with the highest percentage of wins
        percent_wins, action = max(
            (self.wins.get((player, s), 0) / self.plays.get((player, s), 1), a)
            for s, a in states_actions)
        # Display the stats for each possible play
        print('Action: Win Rate (Wins / Plays)')
        for x in sorted(
                ((100 * self.wins.get((player, s), 0) / self.plays.get((player, s), 1),
                  self.wins.get((player, s), 0), self.plays.get((player, s), 0), a)
                 for s, a in states_actions),
                reverse=True):
            print('{3}: {0:.2f}% ({1} / {2})'.format(*x))
        print('Maximum Search Depth:', self.max_depth)
        return action
    # Plays out a pseudorandom game from the current position and updates the statistics tables
    def run_simulation(self):
        visited_states = set()
        history_copy = deepcopy(self.history)  # Keep separate copy of canonical game tree
        state = history_copy[-1]
        player = self.board.current_player(state)
        hashable_state = self.board.hash(state)
        expand = True
        for t in range(1, self.max_moves + 1):
            legal = self.board.legal_actions(state)
            # Compute the (next state, action) pair for every legal move
            states_actions = [(self.board.step(state, a), a) for a in legal]
            if all(self.plays.get((player, self.board.hash(s))) for s, a in states_actions):
                # If we have stats on all of the legal moves, use Upper Confidence Bound 1 applied to trees (UCT) to choose the next action
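                # UCT value of a child s = wins[s] / plays[s] + c * sqrt(ln(total plays over all children) / plays[s]),
                # i.e. the empirical win rate (exploitation) plus a bonus for rarely tried moves (exploration)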
                log_total = log(sum(self.plays[(player, self.board.hash(s))] for s, a in states_actions))
                value, action, state = max(
                    ((self.wins[(player, self.board.hash(s))] / self.plays[(player, self.board.hash(s))]) + self.c * sqrt(log_total / self.plays[(player, self.board.hash(s))]), a, s)
                    for s, a in states_actions)
            else:
                # Otherwise choose a random move
                state, action = choice(states_actions)
            history_copy.append(state)
            hashable_state = self.board.hash(state)
            # Initialise stats for moving player (if necessary)
            if expand and (player, hashable_state) not in self.plays:
                expand = False
                self.plays[(player, hashable_state)] = 0
                self.wins[(player, hashable_state)] = 0
                self.max_depth = max(t, self.max_depth)
            visited_states.add((player, hashable_state))
            player = self.board.current_player(state)
            winner = self.board.winner(state)
            if winner:
                break
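        # Backpropagation: update the visit and win counts for every (player, state) pair
        # encountered during this simulation that already has an entry in the statistics tables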
        for v_player, v_state in visited_states:  # Contains hashable states
            if (v_player, v_state) not in self.plays:
                continue
            self.plays[(v_player, v_state)] += 1
            if v_player == winner:
                self.wins[(v_player, v_state)] += 1
# Play one game
if __name__ == '__main__':
    env = Board()
    mcts = MCTS(env)
    state, winner = env.reset(), 0
    env.pretty_print(state)
    while not winner:
        mcts.update(state)
        action = mcts.get_action()
        state = env.step(state, action)
        env.pretty_print(state)
        winner = env.winner(state)
    if winner == -1:
        print('Game ended in a draw')
    else:
        print('Winner: Player', winner)
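    # Note: the search budget can be tuned via keyword arguments, e.g. MCTS(env, time=5, max_moves=50)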