import copy
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from all.approximation import QDist, FixedTarget
from all.agents import Rainbow, RainbowTestAgent
from all.bodies import DeepmindAtariBody
from all.logging import DummyWriter
from all.memory import PrioritizedReplayBuffer, NStepReplayBuffer
from all.optim import LinearScheduler
from all.presets.atari.models import nature_rainbow
from all.presets.preset import Preset
from all.presets import PresetBuilder
from all.agents.independent import IndependentMultiagent
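# Multiagent Rainbow preset built on the autonomous-learning-library (`all`):
# a single shared Q-network and replay buffer are wrapped per agent further
# below, so all agents learn from pooled experience (parameter sharing).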
default_hyperparameters = {
    "discount_factor": 0.99,
    "lr": 6.25e-5,
    "eps": 1.5e-4,
    # Training settings
    "minibatch_size": 32,
    "update_frequency": 4,
    "target_update_frequency": 1000,
    # Replay buffer settings
    "replay_start_size": 80000,
    "replay_buffer_size": 1000000,
    # Explicit exploration
    "initial_exploration": 0.02,
    "final_exploration": 0.,
    "test_exploration": 0.001,
    # Prioritized replay settings
    "alpha": 0.5,
    "beta": 0.5,
    # Multi-step learning
    "n_steps": 3,
    # Distributional RL
    "atoms": 51,
    "v_min": -10,
    "v_max": 10,
    # Noisy Nets
    "sigma": 0.5,
    # Model construction
    "model_constructor": nature_rainbow
}
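# These values largely follow the Rainbow defaults used by `all`'s
# single-agent Atari preset; tune per game if needed.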
class RainbowAtariPreset(Preset):
    """
    Rainbow DQN Atari Preset.

    Args:
        env (all.environments.AtariEnvironment): The environment for which to construct the agent.
        name (str): The name of the preset.
        device (torch.device, optional): The device on which to load the agent.

    Keyword Args:
        discount_factor (float): Discount factor for future rewards.
        lr (float): Learning rate for the Adam optimizer.
        eps (float): Stability parameter for the Adam optimizer.
        minibatch_size (int): Number of experiences to sample in each training update.
        update_frequency (int): Number of timesteps per training update.
        target_update_frequency (int): Number of timesteps between target network updates.
        replay_start_size (int): Number of experiences in the replay buffer when training begins.
        replay_buffer_size (int): Maximum number of experiences to store in the replay buffer.
        initial_exploration (float): Initial probability of choosing a random action,
            decayed over the course of training.
        final_exploration (float): Final probability of choosing a random action.
        test_exploration (float): Probability of choosing a random action at test time.
        alpha (float): Amount of prioritization in the prioritized experience replay buffer
            (0 = no prioritization, 1 = full prioritization).
        beta (float): Strength of the importance sampling correction for prioritized experience replay
            (0 = no correction, 1 = full correction).
        n_steps (int): The number of steps for n-step Q-learning.
        atoms (int): The number of atoms in the categorical distribution used to represent
            the distributional value function.
        v_min (int): The expected return corresponding to the smallest atom.
        v_max (int): The expected return corresponding to the largest atom.
        sigma (float): Initial noisy network noise.
        model_constructor (function): The function used to construct the neural model.
    """
    def __init__(self, env, name, device="cuda", **hyperparameters):
        hyperparameters = {**default_hyperparameters, **hyperparameters}
        super().__init__(env, name, hyperparameters)
        # frames=8: 4 stacked frames plus the 4 agent-indicator planes
        # appended by IndicatorBody below.
        self.model = hyperparameters['model_constructor'](env, frames=8, atoms=hyperparameters["atoms"], sigma=hyperparameters["sigma"]).to(device)
        self.hyperparameters = hyperparameters
        self.n_actions = env.action_space.n
        self.device = device
        self.name = name
        self.agent_names = env.agents
    def agent(self, writer=DummyWriter(), train_steps=float('inf')):
        n_updates = (train_steps - self.hyperparameters['replay_start_size']) / self.hyperparameters['update_frequency']
        optimizer = Adam(
            self.model.parameters(),
            lr=self.hyperparameters['lr'],
            eps=self.hyperparameters['eps']
        )
        q_dist = QDist(
            self.model,
            optimizer,
            self.n_actions,
            self.hyperparameters['atoms'],
            scheduler=CosineAnnealingLR(optimizer, n_updates),
            v_min=self.hyperparameters['v_min'],
            v_max=self.hyperparameters['v_max'],
            target=FixedTarget(self.hyperparameters['target_update_frequency']),
            writer=writer,
        )
        replay_buffer = NStepReplayBuffer(
            self.hyperparameters['n_steps'],
            self.hyperparameters['discount_factor'],
            PrioritizedReplayBuffer(
                self.hyperparameters['replay_buffer_size'],
                alpha=self.hyperparameters['alpha'],
                beta=self.hyperparameters['beta'],
                device=self.device,
                store_device="cpu"
            )
        )
        def make_agent(agent_id):
            return DeepmindAtariBody(
                IndicatorBody(
                    Rainbow(
                        q_dist,
                        replay_buffer,
                        exploration=LinearScheduler(
                            self.hyperparameters['initial_exploration'],
                            self.hyperparameters['final_exploration'],
                            0,
                            train_steps - self.hyperparameters['replay_start_size'],
                            name="exploration",
                            writer=writer
                        ),
                        discount_factor=self.hyperparameters['discount_factor'] ** self.hyperparameters["n_steps"],
                        minibatch_size=self.hyperparameters['minibatch_size'],
                        replay_start_size=self.hyperparameters['replay_start_size'],
                        update_frequency=self.hyperparameters['update_frequency'],
                        writer=writer,
                    ),
                    self.agent_names.index(agent_id),
                    len(self.agent_names)
                ),
                lazy_frames=True,
                episodic_lives=True
            )
        return IndependentMultiagent({
            agent_id: make_agent(agent_id)
            for agent_id in self.agent_names
        })
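    # `q_dist` and `replay_buffer` are constructed once and captured by every
    # per-agent closure above, so the agents share network parameters and
    # pooled experience; only the IndicatorBody differs between agents.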
    def test_agent(self):
        q_dist = QDist(
            copy.deepcopy(self.model),
            None,
            self.n_actions,
            self.hyperparameters['atoms'],
            v_min=self.hyperparameters['v_min'],
            v_max=self.hyperparameters['v_max'],
        )
        def make_agent(agent_id):
            return DeepmindAtariBody(
                IndicatorBody(
                    RainbowTestAgent(q_dist, self.n_actions, self.hyperparameters["test_exploration"]),
                    self.agent_names.index(agent_id),
                    len(self.agent_names)
                )
            )
        return IndependentMultiagent({
            agent_id: make_agent(agent_id)
            for agent_id in self.agent_names
        })
rainbow = PresetBuilder('rainbow', default_hyperparameters, RainbowAtariPreset)
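# Hypothetical standalone usage of the builder (matching the calls in main()
# below):
#   preset = rainbow.env(some_env).device("cuda").build()
#   multiagent = preset.agent(train_steps=10_000_000)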
import argparse
from all.environments import MultiagentPettingZooEnv
from all.experiments.multiagent_env_experiment import MultiagentEnvExperiment
from all.presets import atari
import numpy as np
# from all.experiment import run_
from supersuit.aec_wrappers import ObservationWrapper
class InvertColorAgentIndicator(ObservationWrapper):
    def _check_wrapper_params(self):
        assert self.observation_spaces[self.possible_agents[0]].high.dtype == np.dtype('uint8')
        return
    def _modify_spaces(self):
        return
    def _modify_observation(self, agent, observation):
        max_num_agents = len(self.possible_agents)
        if max_num_agents == 2:
            # Invert the second agent's colors so the shared network can
            # distinguish the two agents.
            if agent == self.possible_agents[1]:
                return self.observation_spaces[agent].high - observation
            else:
                return observation
        elif max_num_agents == 4:
            # The original code compared `agent` (a string) to the whole
            # `possible_agents` list, which is always False and made this
            # branch return None; offsetting intensity by a distinct
            # per-agent constant is the presumed intent.
            agent_idx = self.possible_agents.index(agent)
            return np.uint8((255 // 4) * agent_idx) + observation
        return observation
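# Agent indication: with a shared network, each agent's observation must be
# transformed distinctly (color inversion for 2 players, a constant intensity
# offset per agent for 4) so the policy can condition on agent identity.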
from all.core import State, StateArray
from all.bodies._body import Body
import torch
import os
from all.bodies.vision import LazyState, TensorDeviceCache

class IndicatorState(State):
    @classmethod
    def from_state(cls, state, frames, to_cache, agent_idx):
        state = IndicatorState(state, device=frames[0].device)
        state.to_cache = to_cache
        state.agent_idx = agent_idx
        state['observation'] = frames
        return state

    def __getitem__(self, key):
        if key == 'observation':
            obs = dict.__getitem__(self, key)
            if not torch.is_tensor(obs):
                obs = torch.cat(dict.__getitem__(self, key), dim=0)
            # Append agent-indicator planes: same shape as the frame stack,
            # with the plane at agent_idx set to 255 and the rest zero.
            indicator = torch.zeros_like(obs)
            indicator[self.agent_idx] = 255
            return torch.cat([obs, indicator], dim=0)
        return super().__getitem__(key)

    def update(self, key, value):
        x = {}
        for k in self.keys():
            if not k == key:
                x[k] = super().__getitem__(k)
        x[key] = value
        state = IndicatorState(x, device=self.device)
        state.to_cache = self.to_cache
        state.agent_idx = self.agent_idx
        return state

    def to(self, device):
        if device == self.device:
            return self
        x = {}
        for key, value in self.items():
            if key == 'observation':
                x[key] = [self.to_cache.convert(v, device) for v in value]
                # x[key] = [v.to(device) for v in value]  # torch.cat(value, axis=0).to(device)
            elif torch.is_tensor(value):
                x[key] = value.to(device)
            else:
                x[key] = value
        state = IndicatorState.from_state(x, x['observation'], self.to_cache, self.agent_idx)
        return state
class IndicatorBody(Body):
    def __init__(self, agent, agent_idx, num_agents):
        super().__init__(agent)
        self.agent_idx = agent_idx
        self.num_agents = num_agents
        self.to_cache = TensorDeviceCache(max_size=32)

    def process_state(self, state):
        return IndicatorState.from_state(state, dict.__getitem__(state, 'observation'), self.to_cache, self.agent_idx)
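# IndicatorBody wraps each agent's states in an IndicatorState, which lazily
# appends agent-indicator planes to the stacked frames; this is why the
# preset builds its model with frames=8 (4 stacked frames + 4 indicator
# planes).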
class DummyEnv():
    def __init__(self, state_space, action_space, agents):
        self.state_space = state_space
        self.action_space = action_space
        self.agents = agents
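# DummyEnv exposes only the attributes the preset actually reads (the shared
# observation and action spaces plus the agent list), letting the preset be
# built without stepping a real environment.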
def main():
    parser = argparse.ArgumentParser(description="Run a multiagent Atari benchmark.")
    parser.add_argument("env", help="Name of the PettingZoo Atari environment module (e.g. boxing_v1).")
    parser.add_argument(
        "--device",
        default="cuda",
        help="The name of the device to run the agent on (e.g. cpu, cuda, cuda:0).",
    )
    parser.add_argument(
        "--replay_buffer_size",
        type=int,
        default=1000000,
        help="The size of the replay buffer, if applicable.",
    )
    parser.add_argument(
        "--frames", type=int, default=int(50e6), help="The number of training frames."
    )
    parser.add_argument(
        "--render", action="store_true", default=False, help="Render the environment."
    )
    args = parser.parse_args()
    from pettingzoo import atari
    import importlib
    from supersuit import resize_v0, frame_skip_v0, reshape_v0, max_observation_v0
    env = importlib.import_module('pettingzoo.atari.{}'.format(args.env)).env(obs_type='grayscale_image')
    env = max_observation_v0(env, 2)
    env = frame_skip_v0(env, 4)
    env = InvertColorAgentIndicator(env)
    env = resize_v0(env, 84, 84)
    env = reshape_v0(env, (1, 84, 84))
    agent0 = env.possible_agents[0]
    obs_space = env.observation_spaces[agent0]
    act_space = env.action_spaces[agent0]
    # The preset assumes homogeneous agents: every agent must share the same
    # observation and action spaces.
    for agent in env.possible_agents:
        assert obs_space == env.observation_spaces[agent]
        assert act_space == env.action_spaces[agent]
    env_agents = env.possible_agents
    env = MultiagentPettingZooEnv(env, args.env)
    # The second .env() call overrides the first: the preset is built against
    # a lightweight DummyEnv exposing the shared spaces, while the experiment
    # itself runs on the wrapped PettingZoo environment.
    preset = rainbow.env(env).hyperparameters(
        replay_buffer_size=args.replay_buffer_size,
        replay_start_size=80000,
    ).device(args.device).env(
        DummyEnv(
            obs_space, act_space, env_agents
        )
    ).build()
    experiment = MultiagentEnvExperiment(
        preset,
        env,
        write_loss=False,
        render=args.render,
    )
    # run_experiment()
    os.makedirs("checkpoint", exist_ok=True)  # os.mkdir fails if the directory already exists
    num_frames_train = int(args.frames)
    frames_per_save = 200000
    for frame in range(0, num_frames_train, frames_per_save):
        # train() runs until the cumulative frame count is reached, so train
        # up to the next checkpoint boundary (the original passed
        # `frames=frame`, which made the first iteration a no-op).
        experiment.train(frames=frame + frames_per_save)
        torch.save(preset, f"checkpoint/{frame + frames_per_save:08d}.pt")
    # experiment.test(episodes=5)
    experiment._save_model()

if __name__ == "__main__":
    main()
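# Example invocation (hypothetical script name; the PettingZoo module name,
# e.g. boxing_v1, depends on the installed version):
#   python multiagent_rainbow.py boxing_v1 --device cuda --frames 50000000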