David Foster davidADSP

# Core Monte Carlo Tree Search algorithm.
# To decide on an action, we run N simulations, always starting at the root of
# the search tree and traversing the tree according to the UCB formula until we
# reach a leaf node.
def run_mcts(config: MuZeroConfig, root: Node, action_history: ActionHistory,
             network: Network):
  min_max_stats = MinMaxStats(config.known_bounds)

  for _ in range(config.num_simulations):
    history = action_history.clone()
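The preview above cuts off just inside the simulation loop. In the pseudocode released with the MuZero paper, which this snippet appears to follow, each simulation continues roughly as sketched below: walk down the tree with the UCB rule until an unexpanded node is reached, expand it with the learned dynamics and prediction functions, then back the value up the search path. The helpers select_child, expand_node and backpropagate are assumed to be defined as in that pseudocode.

    node = root
    search_path = [node]

    # Selection: follow the UCB-maximising child until a leaf is reached.
    while node.expanded():
      action, node = select_child(config, node, min_max_stats)
      history.add_action(action)
      search_path.append(node)

    # Expansion: inside the search tree the dynamics function gives the next
    # hidden state from the parent's hidden state and the chosen action.
    parent = search_path[-2]
    network_output = network.recurrent_inference(parent.hidden_state,
                                                 history.last_action())
    expand_node(node, history.to_play(), history.action_space(), network_output)

    # Backup: propagate the predicted value back to the root.
    backpropagate(search_path, network_output.value, history.to_play(),
                  config.discount, min_max_stats)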
@davidADSP
davidADSP / ego_graph.py
Last active January 27, 2021 11:28
ego_graph
import networkx as nx

# SAMPLE DATA FORMAT
# nodes = [('tensorflow', {'count': 13}),
#          ('pytorch', {'count': 6}),
#          ('keras', {'count': 6}),
#          ('scikit', {'count': 2}),
#          ('opencv', {'count': 5}),
#          ('spark', {'count': 13}), ...]
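The preview only shows the expected node format. A minimal sketch of extracting an ego graph with networkx for data in that shape is given below; the edge list and the choice of 'tensorflow' as the centre node are illustrative assumptions, not taken from the gist.

nodes = [('tensorflow', {'count': 13}), ('pytorch', {'count': 6}),
         ('keras', {'count': 6}), ('scikit', {'count': 2})]
edges = [('tensorflow', 'keras'), ('tensorflow', 'pytorch'), ('keras', 'scikit')]

G = nx.Graph()
G.add_nodes_from(nodes)   # (node, attribute dict) tuples keep the 'count' data
G.add_edges_from(edges)

# Sub-graph of everything within one hop of 'tensorflow', attributes included.
ego = nx.ego_graph(G, 'tensorflow', radius=1)
print(ego.nodes(data=True))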
def muzero(config: MuZeroConfig):
  storage = SharedStorage()
  replay_buffer = ReplayBuffer(config)

  for _ in range(config.num_actors):
    launch_job(run_selfplay, config, storage, replay_buffer)

  train_network(config, storage, replay_buffer)
  return storage.latest_network()
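This entry point launches one self-play job per actor and then trains on the collected games. In the MuZero paper's pseudocode, which this function appears to follow, each self-play job is a loop of roughly the following form (a sketch; play_game is assumed to be the routine that plays one full episode using MCTS with the latest network):

def run_selfplay(config: MuZeroConfig, storage: SharedStorage,
                 replay_buffer: ReplayBuffer):
  while True:
    network = storage.latest_network()   # always act with the freshest weights
    game = play_game(config, network)    # one MCTS-guided episode
    replay_buffer.save_game(game)        # make it available to the trainer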
@davidADSP
davidADSP / selfplay.py
Created January 23, 2021 22:06
The step method of the SelfPlayEnv class
class SelfPlayEnv(env):
    # ...
    def step(self, action):
        self.render()
        observation, reward, done, _ = super(SelfPlayEnv, self).step(action)
        logger.debug(f'Action played by agent: {action}')
        logger.debug(f'Rewards: {reward}')
        logger.debug(f'Done: {done}')
@davidADSP
davidADSP / selfplay.py
Created January 23, 2021 22:05
The continue_game method of the SelfPlayEnv class
class SelfPlayEnv(env):
    # ...
    def continue_game(self):
        while self.current_player_num != self.agent_player_num:
            self.render()
            action = self.current_agent.choose_action(self, choose_best_action = False, mask_invalid_actions = False)
            observation, reward, done, _ = super(SelfPlayEnv, self).step(action)
            logger.debug(f'Rewards: {reward}')
            logger.debug(f'Done: {done}')
@davidADSP
davidADSP / selfplay.py
Last active January 23, 2021 22:04
reset wrapper for self-play environment
class SelfPlayEnv(env):
    # ...
    def reset(self):
        super(SelfPlayEnv, self).reset()
        self.setup_opponents()

        if self.current_player_num != self.agent_player_num:
            self.continue_game()
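With reset, continue_game and step wrapped this way, the opponents' moves become part of the environment's own dynamics, so the learning agent can treat it like any single-agent Gym environment. A minimal usage sketch follows; the constructor call and the model variable are illustrative assumptions, not taken from the gists.

env = SelfPlayEnv()                            # assumed: a concrete game subclass
obs = env.reset()                              # opponents may move before control returns
done = False

while not done:
    action, _ = model.predict(obs)             # e.g. a trained stable-baselines model
    obs, reward, done, info = env.step(action)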
@davidADSP
davidADSP / train.py
Last active January 23, 2021 20:07
Training a PPO model on Pendulum
import gym
from stable_baselines import PPO1
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.callbacks import EvalCallback
env = gym.make('Pendulum-v0')
model = PPO1(MlpPolicy, env)
# Separate evaluation env
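The preview ends at the "Separate evaluation env" comment. A hedged sketch of how such a stable-baselines script typically continues is shown below; the save path, evaluation frequency and number of timesteps are illustrative values, not taken from the gist.

eval_env = gym.make('Pendulum-v0')

# Periodically evaluate on the separate env and keep the best model so far.
eval_callback = EvalCallback(eval_env,
                             best_model_save_path='./logs/',
                             eval_freq=10000)

model.learn(total_timesteps=100000, callback=eval_callback)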
class Game(object):
  """A single episode of interaction with the environment."""

  def __init__(self, action_space_size: int, discount: float):
    self.environment = Environment()  # Game specific environment.
    self.history = []
    self.rewards = []
    self.child_visits = []
    self.root_values = []
    self.action_space_size = action_space_size
    self.discount = discount
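The child_visits and root_values lists created above hold the policy and value targets used in training. In the MuZero paper's pseudocode, which this class appears to follow, they are filled after every MCTS search, roughly as in the sketch below (assuming the Node and Action classes from that pseudocode):

  def store_search_statistics(self, root: Node):
    # Normalised visit counts at the root become the policy target.
    sum_visits = sum(child.visit_count for child in root.children.values())
    action_space = (Action(index) for index in range(self.action_space_size))
    self.child_visits.append([
        root.children[a].visit_count / sum_visits if a in root.children else 0
        for a in action_space
    ])
    # The root's averaged value becomes the value target.
    self.root_values.append(root.value())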
class ReplayBuffer(object):
  def __init__(self, config: MuZeroConfig):
    self.window_size = config.window_size
    self.batch_size = config.batch_size
    self.buffer = []

  def sample_batch(self, num_unroll_steps: int, td_steps: int):
    games = [self.sample_game() for _ in range(self.batch_size)]
    game_pos = [(g, self.sample_position(g)) for g in games]
    return [(g.make_image(i), g.history[i:i + num_unroll_steps],
             g.make_target(i, num_unroll_steps, td_steps, g.to_play()))
            for (g, i) in game_pos]
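sample_game and sample_position are called above but fall outside the preview; in the MuZero paper's pseudocode they are deliberately left as placeholders for whatever sampling scheme (uniform or prioritised) a real implementation chooses, roughly:

  def sample_game(self) -> Game:
    # Sample a game from the buffer, uniformly or by some priority.
    return self.buffer[0]

  def sample_position(self, game) -> int:
    # Sample a position within the game, uniformly or by some priority.
    return -1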
class ReplayBuffer(object):
  def __init__(self, config: MuZeroConfig):
    self.window_size = config.window_size
    self.batch_size = config.batch_size
    self.buffer = []

  def save_game(self, game):
    if len(self.buffer) > self.window_size:
      self.buffer.pop(0)
    self.buffer.append(game)
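save_game is the producer side of the buffer (called from the self-play loop) and sample_batch the consumer side. In the MuZero paper's pseudocode, the trainer pulls batches from the buffer in a loop roughly like the sketch below; the optimiser setup and update_weights are assumed from that pseudocode rather than shown in these gists.

def train_network(config: MuZeroConfig, storage: SharedStorage,
                  replay_buffer: ReplayBuffer):
  network = Network()
  optimizer = ...  # SGD with momentum and a decaying learning rate in the original

  for i in range(config.training_steps):
    if i % config.checkpoint_interval == 0:
      storage.save_network(i, network)           # actors pick up the new weights
    batch = replay_buffer.sample_batch(config.num_unroll_steps, config.td_steps)
    update_weights(optimizer, network, batch, config.weight_decay)

  storage.save_network(config.training_steps, network)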