@benblack769
Last active April 22, 2021 20:23
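# Multi-agent Rainbow DQN for PettingZoo Atari environments, built on the
# autonomous-learning-library (ALL): a shared Rainbow preset, a per-agent indicator
# channel, and a training script driven by MultiagentEnvExperiment.
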
import copy
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from all.approximation import QDist, FixedTarget
from all.agents import Rainbow, RainbowTestAgent
from all.bodies import DeepmindAtariBody
from all.logging import DummyWriter
from all.memory import PrioritizedReplayBuffer, NStepReplayBuffer
from all.optim import LinearScheduler
from all.presets.atari.models import nature_rainbow
from all.presets.preset import Preset
from all.presets import PresetBuilder
from all.agents.independent import IndependentMultiagent
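
# Note: IndicatorBody is defined later in this gist rather than imported here; it appends
# the per-agent indicator channel that the model's 8 input frames (instead of the usual 4)
# account for.
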
default_hyperparameters = {
    "discount_factor": 0.99,
    "lr": 6.25e-5,
    "eps": 1.5e-4,
    # Training settings
    "minibatch_size": 32,
    "update_frequency": 4,
    "target_update_frequency": 1000,
    # Replay buffer settings
    "replay_start_size": 80000,
    "replay_buffer_size": 1000000,
    # Explicit exploration
    "initial_exploration": 0.02,
    "final_exploration": 0.,
    "test_exploration": 0.001,
    # Prioritized replay settings
    "alpha": 0.5,
    "beta": 0.5,
    # Multi-step learning
    "n_steps": 3,
    # Distributional RL
    "atoms": 51,
    "v_min": -10,
    "v_max": 10,
    # Noisy Nets
    "sigma": 0.5,
    # Model construction
    "model_constructor": nature_rainbow
}


class RainbowAtariPreset(Preset):
    """
    Rainbow DQN Atari Preset.

    Args:
        env (all.environments.AtariEnvironment): The environment for which to construct the agent.
        device (torch.device, optional): The device on which to load the agent.

    Keyword Args:
        discount_factor (float): Discount factor for future rewards.
        lr (float): Learning rate for the Adam optimizer.
        eps (float): Stability parameter for the Adam optimizer.
        minibatch_size (int): Number of experiences to sample in each training update.
        update_frequency (int): Number of timesteps per training update.
        target_update_frequency (int): Number of timesteps between updates of the target network.
        replay_start_size (int): Number of experiences in the replay buffer when training begins.
        replay_buffer_size (int): Maximum number of experiences to store in the replay buffer.
        initial_exploration (float): Initial probability of choosing a random action,
            decayed over the course of training.
        final_exploration (float): Final probability of choosing a random action.
        alpha (float): Amount of prioritization in the prioritized experience replay buffer.
            (0 = no prioritization, 1 = full prioritization)
        beta (float): The strength of the importance sampling correction for prioritized experience replay.
            (0 = no correction, 1 = full correction)
        n_steps (int): The number of steps for n-step Q-learning.
        atoms (int): The number of atoms in the categorical distribution used to represent
            the distributional value function.
        v_min (int): The expected return corresponding to the smallest atom.
        v_max (int): The expected return corresponding to the largest atom.
        sigma (float): Initial noisy network noise.
        model_constructor (function): The function used to construct the neural model.
    """

    def __init__(self, env, name, device="cuda", **hyperparameters):
        hyperparameters = {**default_hyperparameters, **hyperparameters}
        super().__init__(env, name, hyperparameters)
        self.model = hyperparameters['model_constructor'](env, frames=8, atoms=hyperparameters["atoms"], sigma=hyperparameters["sigma"]).to(device)
        self.hyperparameters = hyperparameters
        self.n_actions = env.action_space.n
        self.device = device
        self.name = name
        self.agent_names = env.agents

    def agent(self, writer=DummyWriter(), train_steps=float('inf')):
        n_updates = (train_steps - self.hyperparameters['replay_start_size']) / self.hyperparameters['update_frequency']

        optimizer = Adam(
            self.model.parameters(),
            lr=self.hyperparameters['lr'],
            eps=self.hyperparameters['eps']
        )

        q_dist = QDist(
            self.model,
            optimizer,
            self.n_actions,
            self.hyperparameters['atoms'],
            scheduler=CosineAnnealingLR(optimizer, n_updates),
            v_min=self.hyperparameters['v_min'],
            v_max=self.hyperparameters['v_max'],
            target=FixedTarget(self.hyperparameters['target_update_frequency']),
            writer=writer,
        )

        replay_buffer = NStepReplayBuffer(
            self.hyperparameters['n_steps'],
            self.hyperparameters['discount_factor'],
            PrioritizedReplayBuffer(
                self.hyperparameters['replay_buffer_size'],
                alpha=self.hyperparameters['alpha'],
                beta=self.hyperparameters['beta'],
                device=self.device,
                store_device="cpu"
            )
        )
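
        # A single q_dist and replay buffer are shared by every agent (parameter sharing);
        # the agents differ only in the indicator channel added by IndicatorBody.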
        def make_agent(agent_id):
            return DeepmindAtariBody(
                IndicatorBody(
                    Rainbow(
                        q_dist,
                        replay_buffer,
                        exploration=LinearScheduler(
                            self.hyperparameters['initial_exploration'],
                            self.hyperparameters['final_exploration'],
                            0,
                            train_steps - self.hyperparameters['replay_start_size'],
                            name="exploration",
                            writer=writer
                        ),
                        discount_factor=self.hyperparameters['discount_factor'] ** self.hyperparameters["n_steps"],
                        minibatch_size=self.hyperparameters['minibatch_size'],
                        replay_start_size=self.hyperparameters['replay_start_size'],
                        update_frequency=self.hyperparameters['update_frequency'],
                        writer=writer,
                    ),
                    self.agent_names.index(agent_id),
                    len(self.agent_names)
                ),
                lazy_frames=True,
                episodic_lives=True
            )

        return IndependentMultiagent({
            agent_id: make_agent(agent_id)
            for agent_id in self.agent_names
        })

    def test_agent(self):
        q_dist = QDist(
            copy.deepcopy(self.model),
            None,
            self.n_actions,
            self.hyperparameters['atoms'],
            v_min=self.hyperparameters['v_min'],
            v_max=self.hyperparameters['v_max'],
        )

        def make_agent(agent_id):
            return DeepmindAtariBody(
                IndicatorBody(
                    RainbowTestAgent(q_dist, self.n_actions, self.hyperparameters["test_exploration"]),
                    self.agent_names.index(agent_id),
                    len(self.agent_names)
                )
            )

        return IndependentMultiagent({
            agent_id: make_agent(agent_id)
            for agent_id in self.agent_names
        })

rainbow = PresetBuilder('rainbow', default_hyperparameters, RainbowAtariPreset)
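

# The remainder of this gist is the training script: the SuperSuit observation wrapper,
# the agent-indicator state/body, and the experiment entry point.
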
import argparse
from all.environments import MultiagentPettingZooEnv
from all.experiments.multiagent_env_experiment import MultiagentEnvExperiment
from all.presets import atari
import numpy as np
# from all.experiment import run_
from supersuit.aec_wrappers import ObservationWrapper
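
# SuperSuit observation wrapper that makes each player's view distinguishable in pixel
# space: in 2-player games the second player's observation is color-inverted; in
# 4-player games each player's observation gets a distinct constant offset.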
class InvertColorAgentIndicator(ObservationWrapper):
    def _check_wrapper_params(self):
        assert self.observation_spaces[self.possible_agents[0]].high.dtype == np.dtype('uint8')

    def _modify_spaces(self):
        return

    def _modify_observation(self, agent, observation):
        max_num_agents = len(self.possible_agents)
        if max_num_agents == 2:
            if agent == self.possible_agents[1]:
                # Invert the second player's colors.
                return self.observation_spaces[agent].high - observation
            else:
                return observation
        elif max_num_agents == 4:
            # The original compared `agent` against the full agent list, which is never
            # true; as an assumed fix, shift each player's pixels by an offset proportional
            # to its index (wrapping in uint8) so the four views differ.
            idx = self.possible_agents.index(agent)
            return np.uint8(idx * (255 // 4)) + observation
        return observation

from all.core import State, StateArray
from all.bodies._body import Body
import torch
import os
from all.bodies.vision import LazyState, TensorDeviceCache
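
# IndicatorState appends agent-indicator planes to the stacked observation: the extra
# channel at the agent's index is set to 255, so a single shared network can tell which
# agent it is acting for.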
class IndicatorState(State):
    @classmethod
    def from_state(cls, state, frames, to_cache, agent_idx):
        state = IndicatorState(state, device=frames[0].device)
        state.to_cache = to_cache
        state.agent_idx = agent_idx
        state['observation'] = frames
        return state

    def __getitem__(self, key):
        if key == 'observation':
            obs = dict.__getitem__(self, key)
            if not torch.is_tensor(obs):
                obs = torch.cat(dict.__getitem__(self, key), dim=0)
            indicator = torch.zeros_like(obs)
            indicator[self.agent_idx] = 255
            return torch.cat([obs, indicator], dim=0)
        return super().__getitem__(key)

    def update(self, key, value):
        x = {}
        for k in self.keys():
            if not k == key:
                x[k] = super().__getitem__(k)
        x[key] = value
        state = IndicatorState(x, device=self.device)
        state.to_cache = self.to_cache
        state.agent_idx = self.agent_idx
        return state

    def to(self, device):
        if device == self.device:
            return self
        x = {}
        for key, value in self.items():
            if key == 'observation':
                x[key] = [self.to_cache.convert(v, device) for v in value]
                # x[key] = [v.to(device) for v in value]  # torch.cat(value, axis=0).to(device)
            elif torch.is_tensor(value):
                x[key] = value.to(device)
            else:
                x[key] = value
        state = IndicatorState.from_state(x, x['observation'], self.to_cache, self.agent_idx)
        return state
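
# Body wrapper that tags every state passing through an agent's pipeline with that
# agent's index, producing the IndicatorState above.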
class IndicatorBody(Body):
    def __init__(self, agent, agent_idx, num_agents):
        super().__init__(agent)
        self.agent_idx = agent_idx
        self.num_agents = num_agents
        self.to_cache = TensorDeviceCache(max_size=32)

    def process_state(self, state):
        new_state = IndicatorState.from_state(state, dict.__getitem__(state, 'observation'), self.to_cache, self.agent_idx)
        return new_state
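
# Minimal stand-in environment: the preset builder only needs the observation/action
# spaces and the agent list, so this lightweight container is passed instead of the
# fully wrapped environment.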
class DummyEnv():
    def __init__(self, state_space, action_space, agents):
        self.state_space = state_space
        self.action_space = action_space
        self.agents = agents
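
# Entry point: wraps the chosen PettingZoo Atari environment, builds the multi-agent
# Rainbow preset, and runs a checkpointed training loop with MultiagentEnvExperiment.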
def main():
    parser = argparse.ArgumentParser(description="Run a multiagent Atari benchmark.")
    parser.add_argument("env", help="Name of the PettingZoo Atari environment module (imported as pettingzoo.atari.<env>).")
    parser.add_argument(
        "--device",
        default="cuda",
        help="The name of the device to run the agent on (e.g. cpu, cuda, cuda:0).",
    )
    parser.add_argument(
        "--replay_buffer_size",
        type=int,
        default=1000000,
        help="The size of the replay buffer, if applicable.",
    )
    parser.add_argument(
        "--frames", type=int, default=50e6, help="The number of training frames."
    )
    parser.add_argument(
        "--render", action="store_true", default=False, help="Render the environment."
    )
    args = parser.parse_args()

    from pettingzoo import atari
    import importlib
    from supersuit import resize_v0, frame_skip_v0, reshape_v0, max_observation_v0

    # Standard Atari preprocessing: max over 2 frames, skip 4 frames, agent-indicator
    # coloring, then resize to the 1x84x84 grayscale input expected by the model.
    env = importlib.import_module('pettingzoo.atari.{}'.format(args.env)).env(obs_type='grayscale_image')
    env = max_observation_v0(env, 2)
    env = frame_skip_v0(env, 4)
    env = InvertColorAgentIndicator(env)
    env = resize_v0(env, 84, 84)
    env = reshape_v0(env, (1, 84, 84))

    # All agents must share observation and action spaces, since one network is shared.
    agent0 = env.possible_agents[0]
    obs_space = env.observation_spaces[agent0]
    act_space = env.action_spaces[agent0]
    for agent in env.possible_agents:
        assert obs_space == env.observation_spaces[agent]
        assert act_space == env.action_spaces[agent]
    env_agents = env.possible_agents

    env = MultiagentPettingZooEnv(env, args.env)
    # The second .env() call overrides the first, so the preset is built from the
    # lightweight DummyEnv (spaces + agent list) rather than the live environment.
    preset = rainbow.env(env).hyperparameters(
        replay_buffer_size=args.replay_buffer_size,
        replay_start_size=80000,
    ).device(args.device).env(
        DummyEnv(
            obs_space, act_space, env_agents
        )
    ).build()

    experiment = MultiagentEnvExperiment(
        preset,
        env,
        write_loss=False,
        render=args.render,
    )
    # run_experiment()

    os.makedirs("checkpoint", exist_ok=True)
    num_frames_train = int(args.frames)
    frames_per_save = 200000
    # Train in chunks of frames_per_save frames, saving a full preset checkpoint after
    # each chunk. Assumed fix: ALL's train(frames=...) runs until the cumulative frame
    # count reaches the target, so the target must include the current chunk (the
    # original passed `frame`, which lags one chunk behind the checkpoint name).
    for frame in range(0, num_frames_train, frames_per_save):
        experiment.train(frames=frame + frames_per_save)
        torch.save(preset, f"checkpoint/{frame + frames_per_save:08d}.pt")
    # experiment.test(episodes=5)
    experiment._save_model()


if __name__ == "__main__":
    main()
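
# Example invocation (script filename assumed):
#   python multiagent_rainbow.py <pettingzoo_atari_env_module> --device cuda --frames 50000000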