Pseudocode for Stochastic MuZero
# Copyright 2022 DeepMind Technologies Limited.
# Licensed under the Apache License, Version 2.0 and CC BY 4.0.
# You may not use this file except in compliance with these licenses.
# Copies of the licenses can be found at https://www.apache.org/licenses/LICENSE-2.0
# and https://creativecommons.org/licenses/by/4.0/legalcode.

"""Pseudocode description of the Stochastic MuZero algorithm.

This pseudocode was adapted from the original MuZero pseudocode.
"""

# pylint: disable=unused-argument
# pylint: disable=missing-docstring
# pylint: disable=g-explicit-length-test

import abc
import math
from typing import Any, Dict, Callable, List, NamedTuple, Tuple, Union, Optional, Sequence

import dataclasses
import numpy as np

MAXIMUM_FLOAT_VALUE = float('inf')
########################################
####### Environment interface ##########

# An action to apply to the environment.
# It can be a single integer or a list of micro-actions for backgammon.
Action = Any

# The current player to play.
Player = int


class Environment:
  """Implements the rules of the environment."""

  def apply(self, action: Action):
    """Applies an action or a chance outcome to the environment."""

  def observation(self):
    """Returns the observation of the environment to feed to the network."""

  def is_terminal(self) -> bool:
    """Returns true if the environment is in a terminal state."""
    return False

  def legal_actions(self) -> Sequence[Action]:
    """Returns the legal actions for the current state."""
    return []

  def reward(self, player: Player) -> float:
    """Returns the last reward for the player."""
    return 0.0

  def to_play(self) -> Player:
    """Returns the current player to play."""
    return 0
##########################
####### Helpers ##########


class KnownBounds(NamedTuple):
  min: float
  max: float


class MinMaxStats(object):
  """A class that holds the min-max values of the tree."""

  def __init__(self, known_bounds: Optional[KnownBounds]):
    self.maximum = known_bounds.max if known_bounds else -MAXIMUM_FLOAT_VALUE
    self.minimum = known_bounds.min if known_bounds else MAXIMUM_FLOAT_VALUE

  def update(self, value: float):
    self.maximum = max(self.maximum, value)
    self.minimum = min(self.minimum, value)

  def normalize(self, value: float) -> float:
    if self.maximum > self.minimum:
      # We normalize only when we have set the maximum and minimum values.
      return (value - self.minimum) / (self.maximum - self.minimum)
    return value


# A chance outcome.
Outcome = Any

# An object that holds an action or a chance outcome.
ActionOrOutcome = Union[Action, Outcome]

LatentState = List[float]
AfterState = List[float]


class NetworkOutput(NamedTuple):
  value: float
  probabilities: Dict[ActionOrOutcome, float]
  reward: Optional[float] = 0.0
class Network:
  """An instance of the network used by Stochastic MuZero."""

  def representation(self, observation) -> LatentState:
    """Representation function maps from observation to latent state."""
    return []

  def predictions(self, state: LatentState) -> NetworkOutput:
    """Returns the network predictions for a latent state."""
    return NetworkOutput(0, {}, 0)

  def afterstate_dynamics(self,
                          state: LatentState,
                          action: Action) -> AfterState:
    """Implements the dynamics from latent state and action to afterstate."""
    return []

  def afterstate_predictions(self, state: AfterState) -> NetworkOutput:
    """Returns the network predictions for an afterstate."""
    # No reward for afterstate transitions.
    return NetworkOutput(0, {})

  def dynamics(self, state: AfterState, action: Outcome) -> LatentState:
    """Implements the dynamics from afterstate and chance outcome to state."""
    return []

  def encoder(self, observation) -> Outcome:
    """An encoder maps an observation to a chance outcome."""
    return 0


class NFSPAveragePolicyNetwork:
  """An instance of the average policy network for NFSP."""

  def forward(self, observation) -> Dict[Action, float]:
    """Returns the average policy of NFSP."""
    return {}


class NetworkCacher:
  """An object to share the network between the self-play and training jobs."""

  def __init__(self):
    self._networks = {}

  def save_network(self, step: int, network: Network):
    self._networks[step] = network

  def load_network(self) -> Tuple[int, Network]:
    training_step = max(self._networks.keys())
    return training_step, self._networks[training_step]


# Takes the training step and returns the temperature of the softmax policy.
VisitSoftmaxTemperatureFn = Callable[[int], float]

# Returns an instance of the environment.
EnvironmentFactory = Callable[[], Environment]

# The factory for the network.
NetworkFactory = Callable[[], Network]

# The factory for the NFSP network.
NFSPNetworkFactory = Callable[[], NFSPAveragePolicyNetwork]
@dataclasses.dataclass
class StochasticMuZeroConfig:
  # A factory for the environment.
  environment_factory: EnvironmentFactory
  network_factory: NetworkFactory

  # Self-play.
  num_actors: int
  visit_softmax_temperature_fn: VisitSoftmaxTemperatureFn
  num_simulations: int
  discount: float

  # Root prior exploration noise.
  root_dirichlet_alpha: float
  root_dirichlet_fraction: float
  root_dirichlet_adaptive: bool

  # UCB formula.
  pb_c_base: float = 19652
  pb_c_init: float = 1.25

  # If we already have some information about which values occur in the
  # environment, we can use them to initialize the rescaling.
  # This is not strictly necessary, but establishes identical behaviour to
  # AlphaZero in board games.
  known_bounds: Optional[KnownBounds] = None

  # Replay buffer.
  num_trajectories_in_buffer: int = int(1e6)
  batch_size: int = int(128)
  num_unroll_steps: int = 5
  td_steps: int = 6
  td_lambda: float = 1.0

  # Alpha and beta parameters for prioritization.
  # By default they are set to 0, which means uniform sampling.
  priority_alpha: float = 0.0
  priority_beta: float = 0.0

  # Reservoir sampling, used only for NFSP.
  reservoir_replay_size: int = -1

  # A factor to decide the ratio between acting with the average policy
  # and with the best response (Stochastic MuZero).
  # See https://arxiv.org/abs/1603.01121 for more details.
  anticipatory_factor: float = 0.1

  # A network factory for the average policy in NFSP.
  nfsp_network_factory: Optional[NFSPNetworkFactory] = None

  # The learning rate for the NFSP average policy.
  nfsp_learning_rate: float = 5e-3

  # Training.
  training_steps: int = int(1e6)
  export_network_every: int = int(1e3)
  learning_rate: float = 3e-4
  weight_decay: float = 1e-4

  # The number of chance codes (codebook size).
  # We use a codebook of size 32 for all our experiments.
  codebook_size: int = 32
##################################
## Environment specific configs ##


def twentyfortyeight_config() -> StochasticMuZeroConfig:
  """Returns the config for the game of 2048."""

  def environment_factory():
    # Returns an implementation of 2048.
    return Environment()

  def network_factory():
    # 10 layer fully connected Res V2 network with Layer normalization and
    # size 256.
    return Network()

  def visit_softmax_temperature(train_steps: int) -> float:
    if train_steps < 1e5:
      return 1.0
    elif train_steps < 2e5:
      return 0.5
    elif train_steps < 3e5:
      return 0.1
    else:
      # Greedy selection.
      return 0.0

  return StochasticMuZeroConfig(
      environment_factory=environment_factory,
      network_factory=network_factory,
      num_actors=1000,
      visit_softmax_temperature_fn=visit_softmax_temperature,
      num_simulations=100,
      discount=0.999,
      root_dirichlet_alpha=0.3,
      root_dirichlet_fraction=0.1,
      root_dirichlet_adaptive=False,
      num_trajectories_in_buffer=int(125e3),
      td_steps=10,
      td_lambda=0.5,
      priority_alpha=1.0,
      priority_beta=1.0,
      training_steps=int(20e6),
      batch_size=1024,
      weight_decay=0.0)
def backgammon_config() -> StochasticMuZeroConfig:
  """Returns the config for the game of backgammon."""

  def environment_factory():
    # Returns a backgammon environment. We consider single games without a
    # doubling cube.
    return Environment()

  def network_factory():
    # 10 layer fully connected Res V2 network with Layer normalization and
    # size 256.
    return Network()

  def visit_softmax_temperature(train_steps: int) -> float:
    return 1.0

  return StochasticMuZeroConfig(
      environment_factory=environment_factory,
      network_factory=network_factory,
      num_actors=1000,
      visit_softmax_temperature_fn=visit_softmax_temperature,
      num_simulations=1600,
      discount=1.0,
      # Unused, we use adaptive Dirichlet noise for backgammon.
      root_dirichlet_alpha=-1.0,
      root_dirichlet_fraction=0.1,
      root_dirichlet_adaptive=True,
      # Max value is 3 for backgammon.
      known_bounds=KnownBounds(min=-3, max=3),
      # 1e5 full episodes stored.
      num_trajectories_in_buffer=int(1e5),
      # We use Monte Carlo returns.
      td_steps=int(1e3),
      training_steps=int(8e6),
      batch_size=1024,
      learning_rate=3e-4,
      weight_decay=1e-4)
def leduc_poker_config() -> StochasticMuZeroConfig:
  """Returns the config for the game of Leduc poker."""

  def environment_factory():
    # Returns an OpenSpiel implementation of Leduc poker.
    return Environment()

  def network_factory():
    # 2 layer fully connected network with size of 128.
    return Network()

  def nfsp_network_factory():
    # Single layer network with a size of 256.
    return NFSPAveragePolicyNetwork()

  def visit_softmax_temperature(train_steps: int) -> float:
    # Greedy policy throughout.
    return 0.0

  return StochasticMuZeroConfig(
      # A factory for the environment.
      environment_factory=environment_factory,
      network_factory=network_factory,
      num_actors=1000,
      visit_softmax_temperature_fn=visit_softmax_temperature,
      # We use 1600 simulations for Leduc poker given the high
      # number of possible chance outcomes.
      num_simulations=1600,
      discount=1.0,
      root_dirichlet_alpha=0.25,
      root_dirichlet_fraction=0.1,
      root_dirichlet_adaptive=False,
      num_trajectories_in_buffer=int(5e4),
      reservoir_replay_size=int(2e6),
      anticipatory_factor=0.1,
      nfsp_network_factory=nfsp_network_factory,
      nfsp_learning_rate=5e-3,
      training_steps=int(3.5e6),
      batch_size=256,
      export_network_every=int(500),
      learning_rate=3e-4,
      weight_decay=0.0)
##################################
############ Replay ##############


class SearchStats(NamedTuple):
  search_policy: Dict[Action, int]
  search_value: float


class State(NamedTuple):
  """Data for a single state."""
  observation: List[float]
  reward: float
  discount: float
  player: Player
  action: Action
  search_stats: SearchStats
  # For NFSP trajectories we store the average policy.
  # This can be used for V-Trace off-policy correction when training
  # on targets generated by the NFSP average policy.
  nfsp_average_policy: Optional[Dict[Action, float]] = None


Trajectory = Sequence[State]


class ReplayBuffer:
  """A replay buffer to hold the experience generated by the self-play."""

  def __init__(self, config: StochasticMuZeroConfig):
    self.config = config
    self.data = []

  def save(self, seq: Trajectory):
    if len(self.data) > self.config.num_trajectories_in_buffer:
      # Remove the oldest sequence from the buffer.
      self.data.pop(0)
    self.data.append(seq)

  def sample_trajectory(self) -> Trajectory:
    """Samples a trajectory uniformly or using prioritization.

    A prioritized variant is sketched below this class.
    """
    return self.data[0]

  def sample_index(self, seq: Trajectory) -> int:
    """Samples an index in the trajectory uniformly or using prioritization."""
    return 0

  def sample_element(self) -> Trajectory:
    """Samples a single element from the buffer."""
    # Sample a trajectory.
    trajectory = self.sample_trajectory()
    state_idx = self.sample_index(trajectory)
    limit = max([self.config.num_unroll_steps, self.config.td_steps])
    # Returns a sub-trajectory of experience.
    return trajectory[state_idx:state_idx + limit]

  def sample(self) -> Sequence[Trajectory]:
    """Samples a training batch."""
    return [self.sample_element() for _ in range(self.config.batch_size)]
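
# A minimal sketch (not the authors' implementation) of prioritized trajectory
# sampling using priority_alpha and priority_beta from the config. It assumes
# a hypothetical list of per-trajectory priorities aligned with the buffer
# contents, e.g. absolute TD errors. The returned importance weight would be
# used to scale the loss of the sampled trajectory.
def sample_trajectory_prioritized_sketch(
    buffer: ReplayBuffer,
    priorities: Sequence[float]) -> Tuple[Trajectory, float]:
  config = buffer.config
  # Sampling probabilities proportional to priority^alpha.
  scaled = np.array(priorities) ** config.priority_alpha
  probs = scaled / scaled.sum()
  index = np.random.choice(len(buffer.data), p=probs)
  # Importance sampling correction for the bias introduced by prioritization.
  weight = (1.0 / (len(buffer.data) * probs[index])) ** config.priority_beta
  return buffer.data[index], weight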
class ReservoirBuffer(object):
  """A reservoir buffer to hold the experience generated by the self-play."""

  def __init__(self, config: StochasticMuZeroConfig):
    self.config = config
    self.added = 0
    self.data = []

  def save(self, seq: Trajectory):
    for state in seq:
      # Add each state to the reservoir buffer.
      if len(self.data) < self.config.reservoir_replay_size:
        self.data.append(state)
      else:
        if np.random.randint(0, self.added) < self.config.reservoir_replay_size:
          index = np.random.randint(0, self.config.reservoir_replay_size)
          self.data[index] = state
      self.added += 1

  def sample(self) -> Sequence[State]:
    """Samples a training batch."""
    return [np.random.choice(self.data) for _ in range(self.config.batch_size)]
##################################
############ Search ##############


class ActionOutcomeHistory:
  """Simple history container used inside the search.

  Only used to keep track of the actions and chance outcomes executed.
  """

  def __init__(self,
               player: Player,
               history: Optional[List[ActionOrOutcome]] = None):
    self.initial_player = player
    self.history = list(history or [])

  def clone(self):
    return ActionOutcomeHistory(self.initial_player, self.history)

  def add_action_or_outcome(self, action_or_outcome: ActionOrOutcome):
    self.history.append(action_or_outcome)

  def last_action_or_outcome(self) -> ActionOrOutcome:
    return self.history[-1]

  def to_play(self) -> Player:
    # Returns the next player to play based on the initial player and the
    # history of actions and outcomes. For example for backgammon the two
    # players alternate, while for 2048 it is always the same player.
    return 0


class Node(object):
  """A Node in the MCTS search tree."""

  def __init__(self,
               prior: float,
               is_chance: bool = False):
    self.visit_count = 0
    self.to_play = -1
    self.prior = prior
    self.value_sum = 0
    self.children = {}
    self.state = None
    self.is_chance = is_chance
    self.reward = 0

  def expanded(self) -> bool:
    return len(self.children) > 0

  def value(self) -> float:
    if self.visit_count == 0:
      return 0
    return self.value_sum / self.visit_count
# Core Monte Carlo Tree Search algorithm.
# To decide on an action, we run N simulations, always starting at the root of
# the search tree and traversing the tree according to the UCB formula until we
# reach a leaf node.
def run_mcts(config: StochasticMuZeroConfig, root: Node,
             action_outcome_history: ActionOutcomeHistory, network: Network,
             min_max_stats: MinMaxStats):
  for _ in range(config.num_simulations):
    history = action_outcome_history.clone()
    node = root
    search_path = [node]

    while node.expanded():
      action_or_outcome, node = select_child(config, node, min_max_stats)
      history.add_action_or_outcome(action_or_outcome)
      search_path.append(node)

    # Inside the search tree we use the dynamics function to obtain the next
    # hidden state given an action and the previous hidden state.
    parent = search_path[-2]
    if parent.is_chance:
      # The parent is a chance node: afterstate to latent state transition.
      # The last action or outcome is a chance outcome.
      child_state = network.dynamics(parent.state,
                                     history.last_action_or_outcome())
      network_output = network.predictions(child_state)
      # This child is a decision node.
      is_child_chance = False
    else:
      # The parent is a decision node: latent state to afterstate transition.
      # The last action or outcome is an action.
      child_state = network.afterstate_dynamics(
          parent.state, history.last_action_or_outcome())
      network_output = network.afterstate_predictions(child_state)
      # The child is a chance node.
      is_child_chance = True

    # Expand the node.
    expand_node(node, child_state, network_output, history.to_play(),
                is_child_chance)
    # Backpropagate the value up the tree.
    backpropagate(search_path, network_output.value, history.to_play(),
                  config.discount, min_max_stats)
# Select the child with the highest UCB score.
def select_child(config: StochasticMuZeroConfig, node: Node,
                 min_max_stats: MinMaxStats):
  if node.is_chance:
    # If the node is a chance node we sample from the prior.
    outcomes, probs = zip(*[(o, n.prior) for o, n in node.children.items()])
    outcome = np.random.choice(outcomes, p=probs)
    return outcome, node.children[outcome]

  # For decision nodes we use the pUCT formula.
  _, action, child = max(
      (ucb_score(config, node, child, min_max_stats), action, child)
      for action, child in node.children.items())
  return action, child


# The score for a node is based on its value, plus an exploration bonus based
# on the prior.
def ucb_score(config: StochasticMuZeroConfig, parent: Node, child: Node,
              min_max_stats: MinMaxStats) -> float:
  pb_c = math.log((parent.visit_count + config.pb_c_base + 1) /
                  config.pb_c_base) + config.pb_c_init
  pb_c *= math.sqrt(parent.visit_count) / (child.visit_count + 1)

  prior_score = pb_c * child.prior
  if child.visit_count > 0:
    value_score = min_max_stats.normalize(child.reward +
                                          config.discount * child.value())
  else:
    value_score = 0
  return prior_score + value_score


# We expand a node using the value, reward and policy prediction obtained from
# the neural network.
def expand_node(node: Node, state: Union[LatentState, AfterState],
                network_output: NetworkOutput, player: Player, is_chance: bool):
  node.to_play = player
  node.state = state
  node.is_chance = is_chance
  node.reward = network_output.reward
  for action, prob in network_output.probabilities.items():
    node.children[action] = Node(prob)


# At the end of a simulation, we propagate the evaluation all the way up the
# tree to the root.
def backpropagate(search_path: List[Node], value: float, to_play: Player,
                  discount: float, min_max_stats: MinMaxStats):
  for node in reversed(search_path):
    node.value_sum += value if node.to_play == to_play else -value
    node.visit_count += 1
    min_max_stats.update(node.value())
    value = node.reward + discount * value


# At the start of each search, we add Dirichlet noise to the prior of the root
# to encourage the search to explore new actions.
def add_exploration_noise(config: StochasticMuZeroConfig, node: Node):
  actions = list(node.children.keys())
  dir_alpha = config.root_dirichlet_alpha
  if config.root_dirichlet_adaptive:
    dir_alpha = 1.0 / np.sqrt(len(actions))
  noise = np.random.dirichlet([dir_alpha] * len(actions))
  frac = config.root_dirichlet_fraction
  for a, n in zip(actions, noise):
    node.children[a].prior = node.children[a].prior * (1 - frac) + n * frac
##################################
############ Self-play ###########


class Actor(metaclass=abc.ABCMeta):
  """An actor to interact with the environment."""

  @abc.abstractmethod
  def reset(self):
    """Resets the player for a new episode."""

  @abc.abstractmethod
  def select_action(self, env: Environment) -> Action:
    """Selects an action for the current state of the environment."""

  @abc.abstractmethod
  def stats(self) -> SearchStats:
    """Returns the stats for the player after it has selected an action."""


class NFSPAveragePolicyActor(Actor):
  """An actor which uses the NFSP average policy network."""

  def __init__(self, cacher: NetworkCacher):
    self.cacher = cacher
    self.network = None

  def reset(self):
    _, self.network = self.cacher.load_network()
    self.last_average_policy = None

  def select_action(self, env: Environment) -> Action:
    self.last_average_policy = self.network.forward(env.observation())
    return softmax_sample(self.last_average_policy, temperature=1.0)

  def stats(self):
    if self.last_average_policy is None:
      raise ValueError('No average policy was computed.')
    return self.last_average_policy
class StochasticMuZeroActor(Actor):
  """An actor which selects actions using the Stochastic MuZero search."""

  def __init__(self,
               config: StochasticMuZeroConfig,
               cacher: NetworkCacher):
    self.config = config
    self.cacher = cacher
    self.training_step = -1
    self.network = None

  def reset(self):
    # Read a network from the cacher for the new episode.
    self.training_step, self.network = self.cacher.load_network()
    self.root = None

  def _mask_illegal_actions(self,
                            env: Environment,
                            outputs: NetworkOutput) -> NetworkOutput:
    """Masks any actions which are illegal at the root."""
    # We mask out and keep only the legal actions.
    masked_policy = {}
    network_policy = outputs.probabilities
    norm = 0
    for action in env.legal_actions():
      if action in network_policy:
        masked_policy[action] = network_policy[action]
      else:
        masked_policy[action] = 0.0
      norm += masked_policy[action]
    # Renormalize the masked policy.
    masked_policy = {a: v / norm for a, v in masked_policy.items()}
    return NetworkOutput(value=outputs.value, probabilities=masked_policy)

  def _select_action(self, root: Node):
    """Selects an action given the root node."""
    # Get the visit count distribution.
    actions, visit_counts = zip(*[
        (action, node.visit_count)
        for action, node in root.children.items()
    ])
    # Temperature of the softmax over visit counts.
    temperature = self.config.visit_softmax_temperature_fn(self.training_step)
    if temperature == 0.0:
      # Greedy selection.
      return actions[np.argmax(visit_counts)]
    # Compute the search policy.
    search_policy = [v ** (1. / temperature) for v in visit_counts]
    norm = sum(search_policy)
    search_policy = [v / norm for v in search_policy]
    return np.random.choice(actions, p=search_policy)

  def select_action(self, env: Environment) -> Action:
    """Selects an action."""
    # New min-max stats for the search tree.
    min_max_stats = MinMaxStats(self.config.known_bounds)
    # At the root of the search tree we use the representation function to
    # obtain a hidden state given the current observation.
    root = Node(0)
    # Provide the history of observations to the representation network to
    # get the initial latent state.
    latent_state = self.network.representation(env.observation())
    # Compute the predictions.
    outputs = self.network.predictions(latent_state)
    # Keep only the legal actions.
    outputs = self._mask_illegal_actions(env, outputs)
    # Expand the root node.
    expand_node(root, latent_state, outputs, env.to_play(), is_chance=False)
    # Backpropagate the value.
    backpropagate([root], outputs.value, env.to_play(),
                  self.config.discount, min_max_stats)
    # We add exploration noise to the root node.
    add_exploration_noise(self.config, root)
    # We then run a Monte Carlo Tree Search using only action sequences and the
    # model learned by the network.
    run_mcts(self.config, root, ActionOutcomeHistory(env.to_play()),
             self.network, min_max_stats)
    # Keep track of the root to return the stats.
    self.root = root
    # Return an action.
    return self._select_action(root)

  def stats(self) -> SearchStats:
    """Returns the stats of the latest search."""
    if self.root is None:
      raise ValueError('No search was executed.')
    return SearchStats(
        search_policy={
            action: node.visit_count
            for action, node in self.root.children.items()
        },
        search_value=self.root.value())
# Standard self-play.
# Each self-play job is independent of all others; it takes the latest network
# snapshot, produces an episode and makes it available to the training job by
# writing it to a shared replay buffer.
def run_selfplay(config: StochasticMuZeroConfig,
                 cacher: NetworkCacher,
                 replay_buffer: ReplayBuffer):
  actor = StochasticMuZeroActor(config, cacher)

  while True:
    # Create a new instance of the environment.
    env = config.environment_factory()
    # Reset the actor.
    actor.reset()
    episode = []
    while not env.is_terminal():
      action = actor.select_action(env)
      state = State(
          observation=env.observation(),
          reward=env.reward(env.to_play()),
          discount=config.discount,
          player=env.to_play(),
          action=action,
          search_stats=actor.stats())
      episode.append(state)
      env.apply(action)
    # Send the episode to the replay buffer.
    replay_buffer.save(episode)


# NFSP self-play.
def run_nfsp_selfplay(config: StochasticMuZeroConfig,
                      cacher: NetworkCacher,
                      nfsp_cacher: NetworkCacher,
                      replay_buffer: ReplayBuffer,
                      reservoir_buffer: ReservoirBuffer):
  best_response_actor = StochasticMuZeroActor(config, cacher)
  average_policy_actor = NFSPAveragePolicyActor(nfsp_cacher)

  while True:
    # Create a new instance of the environment.
    env = config.environment_factory()
    # Reset the actors.
    best_response_actor.reset()
    average_policy_actor.reset()
    # The agent perceives the environment from a single player
    # perspective, so we store the data for each player separately.
    episodes = [[], []]
    # Whether each player executes actions using the average policy or the
    # best response one.
    use_avg_policy = [
        np.random.uniform() > config.anticipatory_factor,
        np.random.uniform() > config.anticipatory_factor
    ]
    while not env.is_terminal():
      current_player = env.to_play()
      best_response_action = best_response_actor.select_action(env)
      avg_policy_action = average_policy_actor.select_action(env)
      if use_avg_policy[current_player]:
        action = avg_policy_action
      else:
        action = best_response_action
      state = State(
          observation=env.observation(),
          reward=env.reward(env.to_play()),
          discount=config.discount,
          player=current_player,
          action=action,
          search_stats=best_response_actor.stats(),
          nfsp_average_policy=average_policy_actor.stats())
      episodes[current_player].append(state)
      env.apply(action)
    # Send the episodes for each player to the replay buffer.
    for episode in episodes:
      replay_buffer.save(episode)
    # We store only the best response episodes in the reservoir buffer.
    for use_avg_p, episode in zip(use_avg_policy, episodes):
      if not use_avg_p:
        reservoir_buffer.save(episode)
##################################
############ Training ############


class Learner(metaclass=abc.ABCMeta):
  """A learner to update the network weights based on the stored experience."""

  @abc.abstractmethod
  def learn(self):
    """Single training step of the learner."""

  @abc.abstractmethod
  def export(self) -> Union[Network, NFSPAveragePolicyNetwork]:
    """Exports the network."""


def policy_loss(predictions, labels):
  """Minimizes the KL divergence between the predictions and the labels."""
  return 0.0
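
# A minimal sketch (not the authors' implementation) of such a policy loss:
# the cross-entropy between a target distribution ("labels", e.g. normalized
# visit counts) and the predicted distribution, which minimizes the KL
# divergence up to the constant entropy of the target.
def policy_loss_sketch(predictions: Dict[Any, float],
                       labels: Dict[Any, float]) -> float:
  eps = 1e-12
  norm = float(sum(labels.values())) or 1.0
  return -sum((v / norm) * math.log(predictions.get(a, 0.0) + eps)
              for a, v in labels.items())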
def value_or_reward_loss(prediction, target):
  """Implements the value or reward loss for Stochastic MuZero.

  For backgammon this is implemented as an MSE loss of scalars.
  For 2048 and Leduc poker, we use the two hot representation proposed in
  MuZero, and this loss is implemented as a KL divergence between the value
  and value target representations. For 2048 we also apply a hyperbolic
  transformation to the target (see paper for more information). A sketch of
  the transformation and the two hot representation is given below.

  Args:
    prediction: The reward or value output of the network.
    target: The reward or value target.

  Returns:
    The loss to minimize.
  """
  return 0.0
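
# A minimal sketch (not the authors' implementation) of the two ingredients
# mentioned in the docstring above: the invertible value transformation from
# the MuZero paper and a two hot projection onto an integer support. The
# default support size and epsilon below are illustrative assumptions.
def value_transform_sketch(x: float, eps: float = 1e-3) -> float:
  # h(x) = sign(x) * (sqrt(|x| + 1) - 1) + eps * x.
  return np.sign(x) * (np.sqrt(np.abs(x) + 1.0) - 1.0) + eps * x


def two_hot_sketch(x: float, support_size: int = 601) -> np.ndarray:
  """Projects a transformed scalar onto a categorical support [0, support_size)."""
  x = np.clip(x, 0.0, support_size - 1)
  low = int(np.floor(x))
  high = min(low + 1, support_size - 1)
  encoding = np.zeros(support_size)
  # The probability mass is split between the two nearest support points.
  encoding[high] = x - low
  encoding[low] += 1.0 - (x - low)
  return encoding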
class StochasticMuZeroLearner(Learner):
  """Implements the learning for Stochastic MuZero."""

  def __init__(self,
               config: StochasticMuZeroConfig,
               replay_buffer: ReplayBuffer):
    self.config = config
    self.replay_buffer = replay_buffer
    # Instantiate the network.
    self.network = config.network_factory()

  def transpose_to_time(self, batch):
    """Transposes the data so the leading dimension is time instead of batch."""
    return batch

  def learn(self):
    """Applies a single training step."""
    batch = self.replay_buffer.sample()
    # Transpose batch to make time the leading dimension.
    batch = self.transpose_to_time(batch)
    # Compute the initial step loss.
    latent_state = self.network.representation(batch[0].observation)
    predictions = self.network.predictions(latent_state)
    # Compute the td target for the 0th position.
    value_target = compute_td_target(self.config.td_steps,
                                     self.config.td_lambda,
                                     batch)
    # Train the network value towards the td target.
    total_loss = value_or_reward_loss(predictions.value, value_target)
    # Train the network policy towards the MCTS policy.
    total_loss += policy_loss(predictions.probabilities,
                              batch[0].search_stats.search_policy)
    # Unroll the model for k steps.
    for t in range(1, self.config.num_unroll_steps + 1):
      # Condition the afterstate on the previous action.
      afterstate = self.network.afterstate_dynamics(
          latent_state, batch[t - 1].action)
      afterstate_predictions = self.network.afterstate_predictions(afterstate)
      # Call the encoder on the next observation.
      # The encoder returns the chance code, which is a discrete one hot code.
      # The gradients flow to the encoder using a straight through estimator
      # (a sketch of such an encoder is given after this class).
      chance_code = self.network.encoder(batch[t].observation)
      # The afterstate value is trained towards the previous value target
      # but conditioned on the selected action to obtain a Q-estimate.
      total_loss += value_or_reward_loss(
          afterstate_predictions.value, value_target)
      # The afterstate distribution is trained to predict the chance code
      # generated by the encoder.
      total_loss += policy_loss(afterstate_predictions.probabilities,
                                chance_code)
      # Get the dynamics predictions.
      latent_state = self.network.dynamics(afterstate, chance_code)
      predictions = self.network.predictions(latent_state)
      # Compute the new value target.
      value_target = compute_td_target(self.config.td_steps,
                                       self.config.td_lambda,
                                       batch[t:])
      # The reward loss for the dynamics network.
      total_loss += value_or_reward_loss(predictions.reward, batch[t].reward)
      total_loss += value_or_reward_loss(predictions.value, value_target)
      total_loss += policy_loss(predictions.probabilities,
                                batch[t].search_stats.search_policy)

    minimize_with_adam_and_weight_decay(total_loss,
                                        learning_rate=self.config.learning_rate,
                                        weight_decay=self.config.weight_decay)

  def export(self) -> Network:
    return self.network
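
# A minimal sketch (not the authors' implementation) of a straight through
# one hot encoder, as referenced in the learn() comment above. It assumes an
# encoder head that produces logits over the codebook; stop_gradient below is
# a stand-in for a framework primitive (with numpy alone no gradients flow,
# so this is purely illustrative).
def stop_gradient(x):
  """Placeholder for a framework stop-gradient primitive (identity here)."""
  return x


def straight_through_one_hot_sketch(encoder_logits: np.ndarray) -> np.ndarray:
  # Softmax over the codebook (of size config.codebook_size).
  exp = np.exp(encoder_logits - np.max(encoder_logits))
  probs = exp / exp.sum()
  # Discrete one hot chance code used in the forward pass.
  one_hot = np.zeros_like(probs)
  one_hot[np.argmax(probs)] = 1.0
  # Straight through estimator: the forward value is one_hot, while gradients
  # flow through probs as if the argmax were the identity.
  return probs + stop_gradient(one_hot - probs)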
class NFSPAveragePolicyLearner(Learner):
  """Implements the learning for the average policy in NFSP."""

  def __init__(self,
               config: StochasticMuZeroConfig,
               reservoir_buffer: ReservoirBuffer):
    self.config = config
    self.reservoir_buffer = reservoir_buffer
    # Instantiate the network.
    self.network = config.nfsp_network_factory()

  def learn(self):
    """Applies a single training step."""
    batch = self.reservoir_buffer.sample()
    nfsp_policy = self.network.forward(batch.observation)
    # Move the NFSP policy towards the search policy.
    total_loss = policy_loss(nfsp_policy, batch.search_stats.search_policy)
    minimize_with_sgd(total_loss,
                      learning_rate=self.config.nfsp_learning_rate)

  def export(self) -> NFSPAveragePolicyNetwork:
    return self.network


def train_stochastic_muzero(config: StochasticMuZeroConfig,
                            cacher: NetworkCacher,
                            replay_buffer: ReplayBuffer):
  learner = StochasticMuZeroLearner(config, replay_buffer)
  # Export the network so the actors can start generating experience.
  cacher.save_network(0, learner.export())

  for step in range(config.training_steps):
    # Single learning step.
    learner.learn()
    if step > 0 and step % config.export_network_every == 0:
      cacher.save_network(step, learner.export())


def train_nfsp_average_policy(config: StochasticMuZeroConfig,
                              cacher: NetworkCacher,
                              reservoir_buffer: ReservoirBuffer):
  learner = NFSPAveragePolicyLearner(config, reservoir_buffer)
  # Export the network so the actors can start generating experience.
  cacher.save_network(0, learner.export())

  for step in range(config.training_steps):
    # Single learning step.
    learner.learn()
    if step > 0 and step % config.export_network_every == 0:
      cacher.save_network(step, learner.export())
##################################
############ RL loop #############


def launch_stochastic_muzero(config: StochasticMuZeroConfig):
  """Full RL loop for Stochastic MuZero."""
  replay_buffer = ReplayBuffer(config)
  cacher = NetworkCacher()
  # Launch a learner job.
  launch_job(lambda: train_stochastic_muzero(config, cacher, replay_buffer))
  # Launch the actors.
  for _ in range(config.num_actors):
    launch_job(lambda: run_selfplay(config, cacher, replay_buffer))


def launch_nfsp_stochastic_muzero(config: StochasticMuZeroConfig):
  """Full RL loop for NFSP with Stochastic MuZero as the best response."""
  replay_buffer = ReplayBuffer(config)
  reservoir_buffer = ReservoirBuffer(config)
  # Cacher for the Stochastic MuZero network.
  cacher = NetworkCacher()
  # Cacher for the average policy network.
  nfsp_cacher = NetworkCacher()
  # Launch a learner for Stochastic MuZero.
  launch_job(lambda: train_stochastic_muzero(config, cacher, replay_buffer))
  # Launch a learner for the average policy.
  launch_job(
      lambda: train_nfsp_average_policy(config, nfsp_cacher, reservoir_buffer))
  # Launch the actors.
  def nfsp_selfplay():
    return run_nfsp_selfplay(config, cacher, nfsp_cacher, replay_buffer,
                             reservoir_buffer)

  for _ in range(config.num_actors):
    launch_job(nfsp_selfplay)
# Stubs to make the typechecker happy.
def softmax_sample(distribution, temperature: float):
  return 0, 0


def compute_td_target(td_steps, td_lambda, trajectory):
  """Computes the TD(lambda) target for the 0th element of the trajectory.

  In the case of NFSP we use V-Trace to correct for the off-policyness of
  generating trajectories with the average policy while training the value
  of the best response. A concrete sketch (without V-Trace) is given below.

  Args:
    td_steps: The number n of the n-step returns.
    td_lambda: The lambda in TD(lambda).
    trajectory: A sequence of states.

  Returns:
    The n-step return.
  """
  return 0.0
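
# A minimal sketch (not the authors' implementation) of the TD(lambda) target
# for trajectory[0]. It assumes the bootstrap value is the stored MCTS search
# value of each state, and omits both the V-Trace correction mentioned above
# and any player-perspective sign handling for two-player games.
def compute_td_lambda_target_sketch(td_steps: int,
                                    td_lambda: float,
                                    trajectory: Trajectory) -> float:

  def n_step_return(n: int) -> float:
    # G_n = r_1 + gamma * r_2 + ... + gamma^(n-1) * r_n + gamma^n * v_n,
    # truncated at the end of the trajectory.
    horizon = min(n, len(trajectory) - 1)
    value, discount = 0.0, 1.0
    for i in range(1, horizon + 1):
      value += discount * trajectory[i].reward
      discount *= trajectory[i].discount
    # Bootstrap from the stored MCTS value of the last state reached.
    value += discount * trajectory[horizon].search_stats.search_value
    return value

  # Lambda-weighted mixture of the 1..td_steps step returns.
  target, weight = 0.0, 1.0
  for n in range(1, td_steps):
    target += weight * (1 - td_lambda) * n_step_return(n)
    weight *= td_lambda
  target += weight * n_step_return(td_steps)
  return target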
def minimize_with_sgd(loss, learning_rate):
  """Minimizes the loss using SGD."""


def minimize_with_adam_and_weight_decay(loss, learning_rate, weight_decay):
  """Minimizes the loss using Adam with weight decay."""


def launch_job(f):
  """Launches a job to run remotely."""
  return f()
You can obtain the reward prediction by passing the LatentState to the
`predictions` function.
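For example, a minimal sketch using the stubbed Network above (the names network, afterstate and chance_outcome are assumed to be in scope):

    latent_state = network.dynamics(afterstate, chance_outcome)
    reward = network.predictions(latent_state).reward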
…On Thu, 18 Aug 2022 at 18:19, Fernando Ribeiro wrote:

Hi Julian,
On line 118, the return of the dynamics network is expected to be a LatentState. But the dynamics formula in the paper expects both the LatentState and the reward:
$s_t^{k+1}, r_t^{k+1} = g(as_t^{k}, c_{t+k+1})$
During MCTS (line 510), your method expects only the LatentState.
Can I ignore the reward from the dynamics model?
Hi Julian,
How should the chance outcome be masked/replaced after the end of the episode when there is no longer a target observation to pass into the encoder?
Thanks