This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# The score for a node is based on its value, plus an exploration bonus based on
# the prior.
def ucb_score(config: MuZeroConfig, parent: Node, child: Node,
              min_max_stats: MinMaxStats) -> float:
    """Return the pUCT score of ``child`` as seen from ``parent``.

    The score is an exploration bonus (the child's prior scaled by ``pb_c``)
    plus the child's value passed through ``min_max_stats.normalize``.

    Args:
      config: supplies ``pb_c_base`` and ``pb_c_init``.
      parent: node whose ``visit_count`` drives the exploration term.
      child: candidate node; its ``prior``, ``visit_count`` and ``value()``
        are read.
      min_max_stats: normalizer applied to the child's value.

    Returns:
      ``prior_score + value_score``.
    """
    # The exploration coefficient grows logarithmically with the parent's
    # visit count.
    pb_c = math.log((parent.visit_count + config.pb_c_base + 1) /
                    config.pb_c_base) + config.pb_c_init
    # sqrt(N_parent) / (N_child + 1): children visited less often relative to
    # their parent receive a larger bonus.
    pb_c *= math.sqrt(parent.visit_count) / (child.visit_count + 1)

    prior_score = pb_c * child.prior
    value_score = min_max_stats.normalize(child.value())
    # Without this return the function would yield None despite `-> float`.
    return prior_score + value_score
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Select the child with the highest UCB score.
def select_child(config: MuZeroConfig, node: Node,
                 min_max_stats: MinMaxStats):
    """Return the ``(action, child)`` pair of ``node`` with the highest UCB score.

    Args:
      config: forwarded to ``ucb_score``.
      node: a node whose ``children`` dict maps action -> child node.
      min_max_stats: forwarded to ``ucb_score`` for value normalization.

    Returns:
      A tuple ``(action, child)``.

    Raises:
      ValueError: if ``node.children`` is empty (``max`` of an empty iterable).
    """
    # Rank by score alone. Comparing whole (score, action, child) tuples falls
    # through to comparing actions (and then child nodes) whenever two scores
    # tie, which raises TypeError for types that define no ordering. On a tie,
    # `max` keeps the first maximal entry in `children` iteration order.
    action, child = max(
        node.children.items(),
        key=lambda item: ucb_score(config, node, item[1], min_max_stats))
    return action, child
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Core Monte Carlo Tree Search algorithm. | |
# To decide on an action, we run N simulations, always starting at the root of | |
# the search tree and traversing the tree according to the UCB formula until we | |
# reach a leaf node. | |
# Parameters as used in the visible lines: `config` supplies `known_bounds` and
# `num_simulations`; `action_history` is cloned at the start of every
# simulation, so each simulation starts from the same history. `root` and
# `network` are not referenced in the lines shown.
# NOTE(review): this fragment is truncated -- after cloning the history the loop
# body does nothing visible, and `min_max_stats` / `history` are otherwise
# unused here. Confirm against the full source.
def run_mcts(config: MuZeroConfig, root: Node, action_history: ActionHistory, | |
network: Network): | |
# One normalizer shared by all simulations, seeded with config.known_bounds
# (presumably tracks value bounds for ucb_score's normalize() -- TODO confirm).
min_max_stats = MinMaxStats(config.known_bounds) | |
for _ in range(config.num_simulations): | |
history = action_history.clone()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# At the start of each search, we add Dirichlet noise to the prior of the root
# to encourage the search to explore new actions.
def add_exploration_noise(config: MuZeroConfig, node: Node):
    """Mix Dirichlet noise into the priors of ``node``'s children, in place.

    Each child's prior becomes ``prior * (1 - frac) + noise * frac``, where
    ``frac`` is ``config.root_exploration_fraction``. The noise vector is drawn
    from a symmetric Dirichlet distribution with concentration
    ``config.root_dirichlet_alpha``, one component per child.

    Args:
      config: supplies ``root_dirichlet_alpha`` and ``root_exploration_fraction``.
      node: node whose ``children`` (action -> child) priors are updated.
    """
    children = list(node.children.values())
    # One non-negative sample per child; the components sum to 1.
    noise = numpy.random.dirichlet([config.root_dirichlet_alpha] * len(children))
    frac = config.root_exploration_fraction
    for child, n in zip(children, noise):
        child.prior = child.prior * (1 - frac) + n * frac
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# We expand a node using the value, reward and policy prediction obtained from | |
# the neural network. | |
# Stores `to_play`, the hidden state and the reward on `node`, then builds an
# unnormalized policy {action: exp(logit)} restricted to `actions` and its sum.
# NOTE(review): the visible lines never read `network_output.value`, despite
# the comment above mentioning it.
# NOTE(review): the loop body is missing in this view; given `policy_sum`, it
# presumably creates children with priors p / policy_sum -- TODO confirm.
# NOTE(review): math.exp raises OverflowError for logits above ~709; consider
# subtracting the max logit before exponentiating.
def expand_node(node: Node, to_play: Player, actions: List[Action], | |
network_output: NetworkOutput): | |
node.to_play = to_play | |
node.hidden_state = network_output.hidden_state | |
node.reward = network_output.reward | |
policy = {a: math.exp(network_output.policy_logits[a]) for a in actions} | |
policy_sum = sum(policy.values()) | |
for action, p in policy.items():
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class NetworkOutput(typing.NamedTuple):
    """Predictions returned by a ``Network`` inference call."""

    value: float
    reward: float
    # Keyed by action; expand_node exponentiates these, so they are treated as
    # unnormalized log-probabilities.
    policy_logits: Dict[Action, float]
    # Copied onto the expanded node by expand_node.
    hidden_state: List[float]
# Model interface used by the search and training loops. Only the signature of
# `initial_inference` is visible here; it is declared to return a NetworkOutput.
# NOTE(review): the method bodies (and any other methods) are truncated in this
# view -- confirm against the full source.
class Network(object): | |
def initial_inference(self, image) -> NetworkOutput:
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Node(object):
    """A node of the search tree."""

    def __init__(self, prior: float):
        self.visit_count = 0
        # Overwritten by expand_node; -1 until the node is expanded.
        self.to_play = -1
        # Prior probability of reaching this node; add_exploration_noise may
        # perturb it at the root.
        self.prior = prior
        # Running total whose mean (over visit_count) is returned by value().
        self.value_sum = 0
        # Maps action -> child Node.
        self.children = {}
        # Set from the network output by expand_node.
        self.hidden_state = None
        self.reward = 0

    def value(self) -> float:
        """Return ``value_sum / visit_count``, or 0 for an unvisited node.

        ``ucb_score`` calls this on child nodes; guarding the zero-visit case
        avoids a ZeroDivisionError for children that were never visited.
        """
        if self.visit_count == 0:
            return 0
        return self.value_sum / self.visit_count
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Each game is produced by starting at the initial board position, then | |
# repeatedly executing a Monte Carlo Tree Search to generate moves until the end | |
# of the game is reached. | |
# Returns a Game (per the annotation). The visible loop runs while the game is
# not terminal and fewer than config.max_moves entries are in game.history.
# NOTE(review): the body is truncated after creating the root Node(0) -- the
# rest of the move loop and the return statement are not visible here.
def play_game(config: MuZeroConfig, network: Network) -> Game: | |
game = config.new_game() | |
while not game.terminal() and len(game.history) < config.max_moves: | |
# At the root of the search tree we use the representation function to | |
# obtain a hidden state given the current observation. | |
root = Node(0)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Trains a fresh Network and checkpoints it into `storage` every
# config.checkpoint_interval steps (including step 0).
# NOTE(review): the body is truncated in this view -- `replay_buffer` and
# `optimizer` are not used in the lines shown, and no training update is visible.
def train_network(config: MuZeroConfig, storage: SharedStorage, | |
replay_buffer: ReplayBuffer): | |
network = Network() | |
# Exponential decay: lr_init * lr_decay_rate ** (global_step / lr_decay_steps).
# NOTE(review): in TF1, tf.train.get_global_step() returns None if no global
# step has been created, which would make this division fail -- confirm one exists.
learning_rate = config.lr_init * config.lr_decay_rate**( | |
tf.train.get_global_step() / config.lr_decay_steps) | |
optimizer = tf.train.MomentumOptimizer(learning_rate, config.momentum) | |
for i in range(config.training_steps): | |
if i % config.checkpoint_interval == 0: | |
storage.save_network(i, network)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Each self-play job is independent of all others; it takes the latest network
# snapshot, produces a game and makes it available to the training job by
# writing it to a shared replay buffer.
def run_selfplay(config: MuZeroConfig, storage: SharedStorage,
                 replay_buffer: ReplayBuffer):
    """Generate self-play games forever, saving each one to ``replay_buffer``.

    ``storage.latest_network()`` is re-fetched before every game, so each game
    is played with the newest snapshot available when it starts. This function
    never returns.

    Args:
      config: forwarded to ``play_game``.
      storage: source of network snapshots.
      replay_buffer: destination for finished games.
    """
    while True:
        network = storage.latest_network()
        game = play_game(config, network)
        replay_buffer.save_game(game)