This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def muzero(config: MuZeroConfig):
    """Run the full MuZero training pipeline.

    Launches ``config.num_actors`` independent self-play jobs that write
    finished games into a shared replay buffer, trains the network on that
    data, and returns the most recent network checkpoint.

    Args:
        config: hyperparameters controlling self-play and training.

    Returns:
        The latest Network stored in shared storage after training.
    """
    storage = SharedStorage()
    replay_buffer = ReplayBuffer(config)

    # Self-play actors run concurrently with the training job; they only
    # communicate through `storage` (network snapshots) and `replay_buffer`
    # (finished games).
    for _ in range(config.num_actors):
        launch_job(run_selfplay, config, storage, replay_buffer)

    train_network(config, storage, replay_buffer)

    return storage.latest_network()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class MuZeroConfig(object): | |
def __init__(self, | |
action_space_size: int, | |
max_moves: int, | |
discount: float, | |
dirichlet_alpha: float, | |
num_simulations: int, | |
batch_size: int, | |
td_steps: int, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class SharedStorage(object): | |
def __init__(self): | |
self._networks = {} | |
def latest_network(self) -> Network: | |
if self._networks: | |
return self._networks[max(self._networks.keys())] | |
else: | |
# policy -> uniform, value -> 0, reward -> 0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class ReplayBuffer(object): | |
def __init__(self, config: MuZeroConfig): | |
self.window_size = config.window_size | |
self.batch_size = config.batch_size | |
self.buffer = [] | |
def save_game(self, game): | |
if len(self.buffer) > self.window_size: | |
self.buffer.pop(0) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Each self-play job is independent of all others; it takes the latest network
# snapshot, produces a game and makes it available to the training job by
# writing it to a shared replay buffer.
def run_selfplay(config: MuZeroConfig, storage: SharedStorage,
                 replay_buffer: ReplayBuffer):
    """Continuously generate self-play games with the freshest network.

    Runs forever: each iteration fetches the latest network snapshot from
    shared storage, plays one complete game with it, and saves the finished
    game to the replay buffer for the training job to consume.

    Args:
        config: hyperparameters for game play and search.
        storage: shared store of network checkpoints (read-only here).
        replay_buffer: shared buffer that receives finished games.
    """
    while True:
        network = storage.latest_network()
        game = play_game(config, network)
        replay_buffer.save_game(game)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def train_network(config: MuZeroConfig, storage: SharedStorage, | |
replay_buffer: ReplayBuffer): | |
network = Network() | |
learning_rate = config.lr_init * config.lr_decay_rate**( | |
tf.train.get_global_step() / config.lr_decay_steps) | |
optimizer = tf.train.MomentumOptimizer(learning_rate, config.momentum) | |
for i in range(config.training_steps): | |
if i % config.checkpoint_interval == 0: | |
storage.save_network(i, network) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Each game is produced by starting at the initial board position, then | |
# repeatedly executing a Monte Carlo Tree Search to generate moves until the end | |
# of the game is reached. | |
def play_game(config: MuZeroConfig, network: Network) -> Game: | |
game = config.new_game() | |
while not game.terminal() and len(game.history) < config.max_moves: | |
# At the root of the search tree we use the representation function to | |
# obtain a hidden state given the current observation. | |
root = Node(0) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Node(object):
    """A node in the Monte Carlo search tree.

    A node starts unexpanded: only ``prior`` is set at construction time.
    ``to_play``, ``hidden_state``, ``reward`` and ``children`` are filled in
    when the node is expanded from a network prediction, while
    ``visit_count`` and ``value_sum`` accumulate during search backups.
    """

    def __init__(self, prior: float):
        self.visit_count = 0      # number of search visits through this node
        self.to_play = -1         # player to move; -1 marks "not yet expanded"
        self.prior = prior        # normalized policy probability for reaching this node
        self.value_sum = 0        # sum of backed-up values across visits
        self.children = {}        # action -> child Node, populated on expansion
        self.hidden_state = None  # latent state from the network, set on expansion
        self.reward = 0           # predicted reward, set on expansion
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class NetworkOutput(typing.NamedTuple):
    """Bundle of predictions produced by one network inference call.

    Fields:
        value: predicted value of the evaluated state.
        reward: predicted immediate reward.
        policy_logits: unnormalized per-action scores; `expand_node`
            exponentiates and normalizes them into child priors.
        hidden_state: latent state stored on the expanded search node.
    """
    value: float
    reward: float
    policy_logits: Dict[Action, float]
    hidden_state: List[float]
class Network(object): | |
def initial_inference(self, image) -> NetworkOutput: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# We expand a node using the value, reward and policy prediction obtained from | |
# the neural network. | |
def expand_node(node: Node, to_play: Player, actions: List[Action], | |
network_output: NetworkOutput): | |
node.to_play = to_play | |
node.hidden_state = network_output.hidden_state | |
node.reward = network_output.reward | |
policy = {a: math.exp(network_output.policy_logits[a]) for a in actions} | |
policy_sum = sum(policy.values()) | |
for action, p in policy.items(): |
OlderNewer