This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Core Monte Carlo Tree Search algorithm. | |
# To decide on an action, we run N simulations, always starting at the root of | |
# the search tree and traversing the tree according to the UCB formula until we | |
# reach a leaf node. | |
def run_mcts(config: MuZeroConfig, root: Node, action_history: ActionHistory, | |
network: Network): | |
min_max_stats = MinMaxStats(config.known_bounds) | |
for _ in range(config.num_simulations): | |
history = action_history.clone() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import networkx as nx | |
# SAMPLE DATA FORMAT | |
#nodes = [('tensorflow', {'count': 13}), | |
# ('pytorch', {'count': 6}), | |
# ('keras', {'count': 6}), | |
# ('scikit', {'count': 2}), | |
# ('opencv', {'count': 5}), | |
# ('spark', {'count': 13}), ...] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def muzero(config: MuZeroConfig):
    """Run self-play and training, then return the latest network.

    Creates a ``SharedStorage`` and a ``ReplayBuffer`` built from ``config``,
    launches ``config.num_actors`` self-play jobs that all receive the same
    storage and replay buffer, and then calls ``train_network`` with those
    same objects.

    Args:
        config: Run configuration. This function reads ``num_actors`` and
            passes the whole object on to the replay buffer, the self-play
            jobs and the trainer.

    Returns:
        Whatever ``storage.latest_network()`` yields once ``train_network``
        has returned.
    """
    storage = SharedStorage()
    replay_buffer = ReplayBuffer(config)

    # One self-play job per configured actor.
    # NOTE(review): presumably launch_job starts the job without blocking, so
    # the actors run alongside training -- confirm against its definition.
    for _ in range(config.num_actors):
        launch_job(run_selfplay, config, storage, replay_buffer)

    train_network(config, storage, replay_buffer)

    return storage.latest_network()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class SelfPlayEnv(env): | |
# ... | |
def step(self, action): | |
self.render() | |
observation, reward, done, _ = super(SelfPlayEnv, self).step(action) | |
logger.debug(f'Action played by agent: {action}') | |
logger.debug(f'Rewards: {reward}') | |
logger.debug(f'Done: {done}') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class SelfPlayEnv(env): | |
# ... | |
def continue_game(self): | |
while self.current_player_num != self.agent_player_num: | |
self.render() | |
action = self.current_agent.choose_action(self, choose_best_action = False, mask_invalid_actions = False) | |
observation, reward, done, _ = super(SelfPlayEnv, self).step(action) | |
logger.debug(f'Rewards: {reward}') | |
logger.debug(f'Done: {done}') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class SelfPlayEnv(env): | |
# ... | |
def reset(self): | |
super(SelfPlayEnv, self).reset() | |
self.setup_opponents() | |
if self.current_player_num != self.agent_player_num: | |
self.continue_game() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import gym | |
from stable_baselines import PPO1 | |
from stable_baselines.common.policies import MlpPolicy | |
from stable_baselines.common.callbacks import EvalCallback | |
env = gym.make('Pendulum-v0') | |
model = PPO1(MlpPolicy, env) | |
# Separate evaluation env |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Game(object): | |
"""A single episode of interaction with the environment.""" | |
def __init__(self, action_space_size: int, discount: float): | |
self.environment = Environment() # Game specific environment. | |
self.history = [] | |
self.rewards = [] | |
self.child_visits = [] | |
self.root_values = [] | |
self.action_space_size = action_space_size |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class ReplayBuffer(object): | |
def __init__(self, config: MuZeroConfig): | |
self.window_size = config.window_size | |
self.batch_size = config.batch_size | |
self.buffer = [] | |
def sample_batch(self, num_unroll_steps: int, td_steps: int): | |
games = [self.sample_game() for _ in range(self.batch_size)] | |
game_pos = [(g, self.sample_position(g)) for g in games] | |
return [(g.make_image(i), g.history[i:i + num_unroll_steps], |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class ReplayBuffer(object): | |
def __init__(self, config: MuZeroConfig): | |
self.window_size = config.window_size | |
self.batch_size = config.batch_size | |
self.buffer = [] | |
def save_game(self, game): | |
if len(self.buffer) > self.window_size: | |
self.buffer.pop(0) |
Newer | Older