Skip to content

Instantly share code, notes, and snippets.

View davidADSP's full-sized avatar

David Foster davidADSP

View GitHub Profile
def muzero(config: MuZeroConfig):
    """Run self-play and training, then return the most recent network.

    Creates one ``SharedStorage`` and one ``ReplayBuffer``, hands both to
    ``config.num_actors`` self-play jobs, and then calls ``train_network``
    with the same two objects.

    Args:
        config: Run configuration. ``num_actors`` is read here; the object
            itself is forwarded to the replay buffer, every self-play job
            and the training routine.

    Returns:
        The value of ``storage.latest_network()`` after ``train_network``
        has returned.
    """
    storage = SharedStorage()
    replay_buffer = ReplayBuffer(config)

    # Every actor gets the same storage and replay buffer instances.
    # NOTE(review): launch_job is not defined in this file; presumably it
    # starts run_selfplay without blocking — confirm.
    for _ in range(config.num_actors):
        launch_job(run_selfplay, config, storage, replay_buffer)

    train_network(config, storage, replay_buffer)

    return storage.latest_network()
# Configuration object passed to every other routine in this file.
# Attributes read from it elsewhere in this file include num_actors,
# window_size, batch_size, max_moves, lr_init, lr_decay_rate, lr_decay_steps,
# momentum, training_steps and checkpoint_interval, plus the new_game() method.
# NOTE(review): the constructor signature is cut off after `td_steps` in this
# excerpt; the remaining parameters and the body are not visible here.
class MuZeroConfig(object):
def __init__(self,
action_space_size: int,
max_moves: int,
discount: float,
dirichlet_alpha: float,
num_simulations: int,
batch_size: int,
td_steps: int,
class SharedStorage(object):
    """In-memory registry of network snapshots keyed by training step."""

    def __init__(self):
        # Maps a step key to the network snapshot saved under that key.
        self._networks = {}

    def latest_network(self) -> Network:
        """Return the snapshot stored under the highest step key.

        Returns:
            The network with the largest key in ``self._networks``, or
            ``None`` when nothing has been saved yet.
        """
        if self._networks:
            return self._networks[max(self._networks)]
        # policy -> uniform, value -> 0, reward -> 0
        # NOTE(review): the comment above describes a default network for the
        # empty case, but nothing that builds one is visible in this file, so
        # this path yields None — confirm the intended fallback.
        return None

    def save_network(self, step: int, network: Network):
        """Store ``network`` as the snapshot for training step ``step``.

        A later call with the same ``step`` overwrites the earlier snapshot.
        """
        self._networks[step] = network
# Holds games for the training job. window_size and batch_size are copied from
# the config, and the list of stored games starts empty.
class ReplayBuffer(object):
def __init__(self, config: MuZeroConfig):
self.window_size = config.window_size
self.batch_size = config.batch_size
self.buffer = []
def save_game(self, game):
# NOTE(review): the body of this size check is cut off in this excerpt;
# presumably it evicts old games once the buffer exceeds window_size — confirm.
if len(self.buffer) > self.window_size:
# Each self-play job is independent of all others; it takes the latest network
# snapshot, produces a game and makes it available to the training job by
# writing it to a shared replay buffer.
def run_selfplay(config: MuZeroConfig, storage: SharedStorage,
                 replay_buffer: ReplayBuffer):
    """Generate self-play games in an endless loop and store each one.

    Args:
        config: Run configuration, forwarded to ``play_game``.
        storage: Source of the network snapshot used for each game.
        replay_buffer: Destination for every finished game.

    This function never returns; the loop has no exit condition.
    """
    while True:
        # Re-read the snapshot on every iteration so that newly saved
        # networks are picked up for the next game.
        network = storage.latest_network()
        game = play_game(config, network)
        # Hand the finished game to the training job, as the header comment
        # describes; without this the game would simply be discarded.
        replay_buffer.save_game(game)
def train_network(config: MuZeroConfig, storage: SharedStorage,
                  replay_buffer: ReplayBuffer):
    """Create a network and checkpoint it to ``storage`` at a fixed interval.

    Args:
        config: Supplies ``lr_init``, ``lr_decay_rate``, ``lr_decay_steps``,
            ``momentum``, ``training_steps`` and ``checkpoint_interval``.
        storage: Receives a snapshot via ``save_network(step, network)``.
        replay_buffer: Accepted for the training job but not read in the
            code visible here.
    """
    network = Network()
    # Learning rate: lr_init * lr_decay_rate ** (global_step / lr_decay_steps).
    learning_rate = config.lr_init * config.lr_decay_rate**(
        tf.train.get_global_step() / config.lr_decay_steps)
    optimizer = tf.train.MomentumOptimizer(learning_rate, config.momentum)

    # NOTE(review): no weight-update step is visible in this excerpt, so
    # `optimizer` and `replay_buffer` are currently unused — confirm against
    # the full source before relying on this loop for training.
    for i in range(config.training_steps):
        # Step 0 satisfies the modulo test, so the fresh network is saved
        # before any further iteration.
        if i % config.checkpoint_interval == 0:
            storage.save_network(i, network)
# Each game is produced by starting at the initial board position, then
# repeatedly executing a Monte Carlo Tree Search to generate moves until the end
# of the game is reached.
# NOTE(review): this excerpt is cut off inside the `while` loop. Only the
# creation of the root node is visible; the search, move selection and the
# return of the finished game are not shown here.
def play_game(config: MuZeroConfig, network: Network) -> Game:
game = config.new_game()
# Loop until the game reports a terminal state or max_moves have been played.
while not game.terminal() and len(game.history) < config.max_moves:
# At the root of the search tree we use the representation function to
# obtain a hidden state given the current observation.
root = Node(0)
class Node(object):
    """A single search-tree node holding visit statistics and child links.

    Args:
        prior: Prior value stored on the node as ``self.prior``.
    """

    def __init__(self, prior: float):
        # Statistics start at zero for a freshly created node.
        self.visit_count = 0
        # -1 is the placeholder until the node is expanded and a player is
        # assigned.
        self.to_play = -1
        self.prior = prior
        self.value_sum = 0
        # Child nodes; a fresh dict per instance, initially empty.
        # NOTE(review): presumably keyed by action — confirm against the
        # code that populates it.
        self.children = {}
        # Filled in from the network output when the node is expanded.
        self.hidden_state = None
        self.reward = 0
class NetworkOutput(typing.NamedTuple):
    """Immutable record of one network inference result.

    Fields, in positional order: ``value``, ``reward``, ``policy_logits``
    and ``hidden_state``.
    """

    # Scalar value prediction.
    value: float
    # Scalar reward prediction.
    reward: float
    # Un-normalised policy score for each action.
    policy_logits: Dict[Action, float]
    # Hidden state produced by the network.
    hidden_state: List[float]
# Neural network wrapper; instantiated with no arguments by the training job.
# NOTE(review): only the signature of initial_inference is visible in this
# excerpt; its body and any further methods of the class are not shown.
class Network(object):
# Takes an observation `image` and is annotated to return a NetworkOutput.
def initial_inference(self, image) -> NetworkOutput:
# We expand a node using the value, reward and policy prediction obtained from
# the neural network.
def expand_node(node: Node, to_play: Player, actions: List[Action],
network_output: NetworkOutput):
node.to_play = to_play
node.hidden_state = network_output.hidden_state
node.reward = network_output.reward
policy = {a: math.exp(network_output.policy_logits[a]) for a in actions}
policy_sum = sum(policy.values())
for action, p in policy.items():