Skip to content

Instantly share code, notes, and snippets.

Created January 24, 2018 00:21
Show Gist options
  • Save bshishov/b3158c327ae7de170793f10a67903ad0 to your computer and use it in GitHub Desktop.
Save bshishov/b3158c327ae7de170793f10a67903ad0 to your computer and use it in GitHub Desktop.
Monte Carlo Tree Search
import numpy as np
from typing import List, Tuple
from alpha_chess_zero import settings
from alpha_chess_zero.game_model import GameModel
class MCTSNode(object):
    """A node of a Monte Carlo search tree.

    Per-edge statistics are stored as parallel numpy arrays indexed by the
    position of the action in ``self.actions``:

    * ``edge_p`` - action probabilities predicted by the estimator in ``expand``
    * ``edge_n`` - visit count N(s, a)
    * ``edge_w`` - total value W(s, a) accumulated through ``update_edge``
    * ``edge_q`` - mean value Q(s, a) = W(s, a) / N(s, a)

    Child nodes are created lazily: ``children[i]`` stays ``None`` until
    ``select`` first walks through edge ``i``.
    """

    def __init__(self, state):
        self.state = state
        self.children = []
        self.actions = []
        # Placeholders for numpy arrays (filled in by expand)
        self.edge_q = None
        self.edge_w = None
        self.edge_p = None
        self.edge_n = None

    def _select_edge(self) -> int:
        """Return the index of the outgoing edge that maximises Q + U."""
        # Clamp to 1 so that U is non-zero before any edge has been visited;
        # otherwise the first pick would ignore the priors entirely.
        n_sqrt_sum = np.maximum(np.sqrt(np.sum(self.edge_n)), 1.0)
        # Mix Dirichlet noise into the move probabilities.
        # NOTE(review): AlphaZero-style search usually adds noise only at the
        # root; here it is resampled at every node on every selection - confirm
        # this is intended.
        noise = np.random.dirichlet([settings.MCTS_DIR_ALPHA] * len(self.edge_p))
        p = (1.0 - settings.MCTS_DIR_EPSILON) * self.edge_p + settings.MCTS_DIR_EPSILON * noise
        u = settings.MCTS_C_PUCT * p * n_sqrt_sum / (1.0 + self.edge_n)
        return int(np.argmax(self.edge_q + u))

    def select(self,
               game_model: 'GameModel',
               path_nodes: List['MCTSNode'],
               path_edge_indices: List[int]) -> Tuple['MCTSNode', List['MCTSNode'], List[int]]:
        """Walk down the tree by repeatedly taking the edge with max Q + U.

        Args:
            game_model: Provides ``get_state_for_action`` to build lazily
                created child states.
            path_nodes: Nodes already on the path (not mutated).
            path_edge_indices: Edge indices already on the path (not mutated).

        Returns:
            ``(leaf, path_nodes, path_edge_indices)``. ``leaf`` is the first
            node reached that has no actions (not yet expanded, or expanded
            with no available actions). The lists are new copies extended
            with every node/edge traversed, so ``path_nodes[i]`` owns edge
            ``path_edge_indices[i]``.
        """
        # Copy so the caller's lists are left untouched, like the original
        # "path + [x]" construction.
        path_nodes = list(path_nodes)
        path_edge_indices = list(path_edge_indices)
        node = self
        # Iterative rather than recursive: deep game trees would otherwise
        # hit Python's recursion limit.
        while len(node.actions) > 0:
            selected = node._select_edge()
            # Lazily create the successor only for the chosen edge, so
            # low-probability successors are never evaluated.
            if node.children[selected] is None:
                child_state = game_model.get_state_for_action(node.state, node.actions[selected])
                node.children[selected] = MCTSNode(child_state)
            path_nodes.append(node)
            path_edge_indices.append(selected)
            node = node.children[selected]
        return node, path_nodes, path_edge_indices

    def expand(self, game_model: 'GameModel', estimator):
        """Populate this node's edges and return the estimator's value for it.

        Next states are *not* computed here; ``select`` builds them on demand
        to avoid evaluating successors that are unlikely to be visited.
        """
        self.actions = game_model.get_actions_for_state(self.state)
        # Any estimator that returns (probabilities over actions, value) works.
        edge_p, value = estimator.predict(self.state, self.actions)
        # Coerce to an array so list-returning estimators also work in select.
        self.edge_p = np.asarray(edge_p, dtype=np.float64)
        action_num = len(self.actions)
        self.edge_q = np.zeros(action_num, dtype=np.float32)
        self.edge_w = np.zeros(action_num, dtype=np.float32)
        # int64, not uint8: a uint8 counter silently wraps to 0 after 255 visits.
        self.edge_n = np.zeros(action_num, dtype=np.int64)
        # Builtin `object`; the `np.object` alias was removed in NumPy 1.24.
        self.children = np.full(action_num, None, dtype=object)
        # Returned so the caller can update the edges along the selected path.
        return value

    def update_edge(self, edge_idx: int, value: float):
        """Record one visit of edge ``edge_idx`` that produced ``value``."""
        self.edge_n[edge_idx] += 1
        self.edge_w[edge_idx] += value
        # Q is the running mean of all values backed up through this edge.
        self.edge_q[edge_idx] = self.edge_w[edge_idx] / self.edge_n[edge_idx]

    def choose_action(self, tau=1.0, deterministic=False) -> Tuple[list, 'MCTSNode', np.ndarray]:
        """Pick an action from this node according to the visit counts.

        Args:
            tau: Temperature (> 0). Action weights are N(s, a) ** (1 / tau).
            deterministic: If True, pick the most-visited action.

        Returns:
            ``(action, child_node_or_None, probabilities)``, where
            ``probabilities`` is the policy the action was drawn from. If no
            edge has been visited yet, the probabilistic policy falls back to
            uniform.
        """
        action_num = len(self.actions)
        if deterministic:
            idx = int(np.argmax(self.edge_n))
            probabilities = np.zeros(action_num)
            probabilities[idx] = 1.0
            return self.actions[idx], self.children[idx], probabilities

        visits = np.asarray(self.edge_n, dtype=np.float64)
        if visits.sum() <= 0:
            # No simulations yet: N ** (1/tau) would normalise to 0/0 = NaN.
            probabilities = np.full(action_num, 1.0 / action_num)
        else:
            # Dividing by max N first gives the same distribution but keeps
            # N ** (1/tau) from overflowing to inf for small tau.
            scaled = np.power(visits / visits.max(), 1.0 / tau)
            probabilities = scaled / scaled.sum()
        idx = np.random.choice(np.arange(action_num), p=probabilities)
        return self.actions[idx], self.children[idx], probabilities

    def run(self, game_model: 'GameModel', estimator, simulations: int):
        """Run ``simulations`` rounds of select -> expand -> back-propagate."""
        for _ in range(simulations):
            leaf, path_nodes, path_edges = self.select(game_model, [], [])
            value = leaf.expand(game_model, estimator)
            # Update N, W and Q for every edge on the path to the leaf.
            # NOTE(review): the same value is added at every depth with no
            # sign flip; for two-player games confirm the estimator already
            # reports values from a consistent perspective.
            for parent, edge_idx in zip(path_nodes, path_edges):
                parent.update_edge(edge_idx, value)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment