Skip to content

Instantly share code, notes, and snippets.

@ttumiel
Created February 15, 2023 22:32
Show Gist options
  • Save ttumiel/ee746d6292cecb47d390fb97c3ccfa5e to your computer and use it in GitHub Desktop.
Save ttumiel/ee746d6292cecb47d390fb97c3ccfa5e to your computer and use it in GitHub Desktop.
## Base Implementation from `ppo_atari_envpool_async_jax_scan_impalanet_machado.py`
## https://github.com/vwxyzjn/cleanrl/blob/52e263887744e022c6b6d0c0de2591c85212c86a/cleanrl/ppo_atari_envpool_async_jax_scan_impalanet_machado.py
# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_atari_envpool_async_jax_scan_impalanet_machadopy
# https://gregorygundersen.com/blog/2020/02/09/log-sum-exp/
import argparse
import os
import random
import time
from distutils.util import strtobool
from typing import Sequence
# Limit how much GPU memory JAX/XLA preallocates to 70% of the device.
# This must run before `import jax` further down, which is why it sits among the imports.
# see https://github.com/google/jax/discussions/6332#discussioncomment-1279991
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.7"
import envpool
import flax
import flax.linen as nn
import gym
import jax
import jax.numpy as jnp
import jax.tree_util as tree
import numpy as np
import optax
from flax.linen.initializers import constant, orthogonal
from flax.training.train_state import TrainState
from torch.utils.tensorboard import SummaryWriter
def parse_args():
    """Parse command-line hyperparameters and attach derived batch sizes.

    Returns:
        argparse.Namespace with every CLI option, plus:
          - batch_size = num_envs * num_steps
          - minibatch_size = batch_size // num_minibatches
          - num_updates = total_timesteps // batch_size
    """
    # fmt: off
    parser = argparse.ArgumentParser()
    # splitext removes exactly the extension; str.rstrip(".py") would strip any trailing
    # '.', 'p' or 'y' characters (e.g. "happy.py" -> "ha").
    parser.add_argument("--exp-name", type=str, default=os.path.splitext(os.path.basename(__file__))[0],
        help="the name of this experiment")
    parser.add_argument("--seed", type=int, default=1,
        help="seed of the experiment")
    parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="if toggled, `torch.backends.cudnn.deterministic=False`")
    parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="if toggled, cuda will be enabled by default")
    parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="if toggled, this experiment will be tracked with Weights and Biases")
    parser.add_argument("--wandb-project-name", type=str, default="cleanRL",
        help="the wandb's project name")
    parser.add_argument("--wandb-entity", type=str, default=None,
        help="the entity (team) of wandb's project")
    parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="whether to capture videos of the agent performances (check out `videos` folder)")
    parser.add_argument("--save-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="whether to save model into the `runs/{run_name}` folder")
    parser.add_argument("--upload-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="whether to upload the saved model to huggingface")
    parser.add_argument("--hf-entity", type=str, default="",
        help="the user or org name of the model repository from the Hugging Face Hub")

    # Algorithm specific arguments
    parser.add_argument("--env-id", type=str, default="Breakout-v5",
        help="the id of the environment")
    parser.add_argument("--total-timesteps", type=int, default=50000000,
        help="total timesteps of the experiments")
    parser.add_argument("--learning-rate", type=float, default=2.5e-4,
        help="the learning rate of the optimizer")
    parser.add_argument("--num-envs", type=int, default=64,
        help="the number of parallel game environments")
    parser.add_argument("--async-batch-size", type=int, default=16,
        help="the envpool's batch size in the async mode")
    parser.add_argument("--num-steps", type=int, default=32,
        help="the number of steps to run in each environment per policy rollout")
    parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="Toggle learning rate annealing for policy and value networks")
    parser.add_argument("--gae", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="Use GAE for advantage computation")
    parser.add_argument("--gamma", type=float, default=0.99,
        help="the discount factor gamma")
    parser.add_argument("--gae-lambda", type=float, default=0.95,
        help="the lambda for the general advantage estimation")
    parser.add_argument("--num-minibatches", type=int, default=2,
        help="the number of mini-batches")
    parser.add_argument("--update-epochs", type=int, default=2,
        help="the K epochs to update the policy")
    parser.add_argument("--norm-adv", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="Toggles advantages normalization")
    parser.add_argument("--clip-coef", type=float, default=0.1,
        help="the surrogate clipping coefficient")
    parser.add_argument("--ent-coef", type=float, default=0.01,
        help="coefficient of the entropy")
    parser.add_argument("--vf-coef", type=float, default=0.5,
        help="coefficient of the value function")
    parser.add_argument("--max-grad-norm", type=float, default=0.5,
        help="the maximum norm for the gradient clipping")
    parser.add_argument("--target-kl", type=float, default=None,
        help="the target KL divergence threshold")
    args = parser.parse_args()
    args.batch_size = int(args.num_envs * args.num_steps)
    args.minibatch_size = int(args.batch_size // args.num_minibatches)
    args.num_updates = args.total_timesteps // args.batch_size
    # fmt: on
    return args
def make_env(env_id, seed, num_envs, async_batch_size=1):
    """Build a zero-argument factory that creates the envpool Atari environment.

    Args:
        env_id: envpool task id, e.g. "Breakout-v5".
        seed: seed forwarded to envpool.
        num_envs: number of parallel environments.
        async_batch_size: envpool ``batch_size``; when smaller than ``num_envs`` each
            ``recv()`` returns only this many environments (async mode).

    Returns:
        A callable that, when invoked, returns the envpool env decorated with the
        gym-vector-env style attributes (``num_envs``, ``single_action_space``,
        ``single_observation_space``, ``is_vector_env``) the rest of the script reads.
    """

    def thunk():
        settings = dict(
            env_type="gym",
            num_envs=num_envs,
            batch_size=async_batch_size,
            episodic_life=False,  # Machado et al. 2017 (Revisiting ALE: Eval protocols) p. 6
            repeat_action_probability=0.25,  # Machado et al. 2017 (Revisiting ALE: Eval protocols) p. 12
            noop_max=1,  # Machado et al. 2017 (Revisiting ALE: Eval protocols) p. 12 (no-op is deprecated in favor of sticky action, right?)
            full_action_space=True,  # Machado et al. 2017 (Revisiting ALE: Eval protocols) Tab. 5
            max_episode_steps=int(108000 / 4),  # Hessel et al. 2018 (Rainbow DQN), Table 3, Max frames per episode
            reward_clip=True,
            seed=seed,
        )
        envs = envpool.make(env_id, **settings)
        # Mirror the attribute names used by gym vector envs.
        envs.num_envs = num_envs
        envs.single_action_space = envs.action_space
        envs.single_observation_space = envs.observation_space
        envs.is_vector_env = True
        return envs

    return thunk
class ResidualBlock(nn.Module):
    """Residual unit: two (ReLU -> 3x3 conv) stages added back onto the block input.

    The skip connection adds the input directly, so the input's last axis must already
    have ``channels`` features.
    """

    channels: int

    @nn.compact
    def __call__(self, x):
        shortcut = x
        # Two identical stages; flax auto-names the convs Conv_0 and Conv_1 in order.
        for _ in range(2):
            x = nn.relu(x)
            x = nn.Conv(self.channels, kernel_size=(3, 3))(x)
        return x + shortcut
class ConvSequence(nn.Module):
    """One IMPALA-style stage: 3x3 conv, 3x3 max-pool with stride 2, then two residual blocks."""

    channels: int

    @nn.compact
    def __call__(self, x):
        x = nn.Conv(self.channels, kernel_size=(3, 3))(x)
        # Stride-2 pooling with SAME padding halves the spatial size (rounded up).
        x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding="SAME")
        for _ in range(2):
            x = ResidualBlock(self.channels)(x)
        return x
class Network(nn.Module):
    """Shared convolutional torso: stacked ConvSequence stages followed by a 256-unit dense layer.

    Input is transposed with axes (0, 2, 3, 1), i.e. it is expected channels-first and
    converted to the channels-last layout flax convolutions use. Pixel values are scaled
    by 1/255.
    """

    channelss: Sequence[int] = (16, 32, 32)

    @nn.compact
    def __call__(self, x):
        x = jnp.transpose(x, (0, 2, 3, 1)) / 255.0
        for stage_width in self.channelss:
            x = ConvSequence(stage_width)(x)
        x = nn.relu(x)
        flat = x.reshape((x.shape[0], -1))
        dense = nn.Dense(256, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))
        return nn.relu(dense(flat))
class Critic(nn.Module):
    """Value head: one linear unit on top of the torso features (output shape (batch, 1))."""

    @nn.compact
    def __call__(self, x):
        value_layer = nn.Dense(1, kernel_init=orthogonal(1), bias_init=constant(0.0))
        return value_layer(x)
class Actor(nn.Module):
    """Policy head: a linear layer producing one logit per discrete action.

    The small orthogonal init gain (0.01) keeps the initial logits near zero.
    """

    # Number of discrete actions. This is an int (the script passes
    # `envs.single_action_space.n`), not a sequence; it is used as nn.Dense's feature count.
    action_dim: int

    @nn.compact
    def __call__(self, x):
        return nn.Dense(self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0))(x)
@flax.struct.dataclass
class AgentParams:
    """Bundle of the three parameter trees that make up the agent.

    Declared with ``flax.struct.dataclass`` so the bundle can be used as a single
    ``TrainState.params`` pytree. Field order matters: the script constructs it
    positionally as (network, actor, critic).
    """

    network_params: flax.core.FrozenDict  # shared torso (Network)
    actor_params: flax.core.FrozenDict  # policy head (Actor)
    critic_params: flax.core.FrozenDict  # value head (Critic)
def batch_obs(o):
    """Concatenate a sequence of observation pytrees leaf-by-leaf along axis 0.

    ``o`` is a sequence of pytrees with identical structure (e.g. a list of arrays). The
    result has that same structure, with each leaf being the concatenation of the
    corresponding leaves.
    """

    def _stack_leaves(*leaves):
        return jnp.concatenate(leaves)

    return tree.tree_map(_stack_leaves, *o)
if __name__ == "__main__":
    # Script setup: parse hyperparameters, start logging, seed RNGs, build the env,
    # then initialise the model parameters and optimizer state.
    args = parse_args()
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    if args.track:
        import wandb

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )
    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    key = jax.random.PRNGKey(args.seed)
    key, network_key, actor_key, critic_key = jax.random.split(key, 4)

    # env setup
    envs = make_env(args.env_id, args.seed, args.num_envs, args.async_batch_size)()
    assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

    def linear_schedule(count):
        """Learning rate after `count` optimizer steps.

        The rate decays linearly once per training iteration, where one iteration
        contains (num_minibatches * update_epochs) gradient updates.
        """
        frac = 1.0 - (count // (args.num_minibatches * args.update_epochs)) / args.num_updates
        return args.learning_rate * frac

    network = Network()
    actor = Actor(action_dim=envs.single_action_space.n)
    critic = Critic()
    # A single dummy observation (with a leading batch axis) is only used to trace
    # shapes for parameter init, so sample it once. Then run the torso once and reuse the
    # features for both heads, instead of resampling and re-running it per head.
    sample_obs = np.array([envs.single_observation_space.sample()])
    network_params = network.init(network_key, sample_obs)
    sample_hidden = network.apply(network_params, sample_obs)
    agent_state = TrainState.create(
        apply_fn=None,
        params=AgentParams(
            network_params,
            actor.init(actor_key, sample_hidden),
            critic.init(critic_key, sample_hidden),
        ),
        tx=optax.chain(
            optax.clip_by_global_norm(args.max_grad_norm),
            # inject_hyperparams exposes the current learning rate in opt_state for logging.
            optax.inject_hyperparams(optax.adam)(
                learning_rate=linear_schedule if args.anneal_lr else args.learning_rate, eps=1e-5
            ),
        ),
    )
@jax.jit
def get_action_and_value(
    agent_state: TrainState,
    next_obs: np.ndarray,
    key: jax.random.PRNGKey,
):
    """Sample one action per observation and evaluate it.

    Returns:
        (action, logprob, value, key): the sampled actions, their log-probabilities
        under the current policy, the critic's values (squeezed), and the advanced PRNG key.
    """
    params = agent_state.params
    features = network.apply(params.network_params, next_obs)
    logits = actor.apply(params.actor_params, features)

    # Gumbel-max trick: argmax(logits + Gumbel noise) draws a sample from softmax(logits).
    # see https://stats.stackexchange.com/questions/359442/sampling-from-a-categorical-distribution
    key, noise_key = jax.random.split(key)
    uniform = jax.random.uniform(noise_key, shape=logits.shape)
    gumbel_noise = -jnp.log(-jnp.log(uniform))
    action = jnp.argmax(logits + gumbel_noise, axis=1)

    rows = jnp.arange(action.shape[0])
    logprob = jax.nn.log_softmax(logits)[rows, action]
    value = critic.apply(params.critic_params, features)
    return action, logprob, value.squeeze(), key
@jax.jit
def get_action_and_value2(
    params: flax.core.FrozenDict,
    x: np.ndarray,
    action: np.ndarray,
):
    """Re-evaluate previously taken actions under `params` (an AgentParams bundle).

    Returns:
        (logprob, entropy, value): log-probability of each given action, the policy
        entropy per row, and the squeezed critic value.
    """
    features = network.apply(params.network_params, x)
    raw_logits = actor.apply(params.actor_params, features)
    chosen = (jnp.arange(action.shape[0]), action)
    logprob = jax.nn.log_softmax(raw_logits)[chosen]

    # Normalise logits into log-probabilities, then clamp -inf to the most negative finite
    # float so that 0 * log(0) cannot produce NaN in the entropy sum.
    log_probs = raw_logits - jax.scipy.special.logsumexp(raw_logits, axis=-1, keepdims=True)
    log_probs = log_probs.clip(min=jnp.finfo(log_probs.dtype).min)
    entropy = -(log_probs * jax.nn.softmax(log_probs)).sum(-1)

    value = critic.apply(params.critic_params, features).squeeze()
    return logprob, entropy, value
def compute_gae_once(carry, x):
    """One reverse-time step of GAE over an async rollout (the body of compute_gae's scan).

    carry: (lastvalues, lastdones, advantages, lastgaelam, final_env_ids, final_env_id_checked).
        The per-env arrays (indexed by env id) hold the value/done/GAE of that env's
        later slot. final_env_id_checked is -1 for envs not yet met in the reverse pass.
    x: (done, value, eid, reward) for one recv() row; eid gives each column's env id.

    When an env is met for the first time (its last slot in the rollout), its TD error is
    forced to 0 and the slot is flagged 0 in final_env_ids, so that slot only supplies the
    bootstrap value. Slots met afterwards (earlier in time) are flagged 1.

    Returns the new carry and (advantages, final_env_ids) for this row.
    """
    prev_values, prev_dones, _, prev_gaelam, _, seen = carry
    step_done, step_value, eid, step_reward = x

    seen_row = seen[eid]
    not_terminal = 1.0 - prev_dones[eid]
    bootstrap = prev_values[eid]
    td_error = jnp.where(seen_row == -1, 0, step_reward + args.gamma * bootstrap * not_terminal - step_value)
    gae = td_error + args.gamma * args.gae_lambda * not_terminal * prev_gaelam[eid]
    trainable = jnp.where(seen_row == 1, 1, 0)

    # Mark envs as seen; the prev_* arrays then track each env's most recently
    # processed (i.e. next-in-time) slot.
    seen = seen.at[eid].set(jnp.where(seen_row == -1, 1, seen_row))
    prev_gaelam = prev_gaelam.at[eid].set(gae)
    prev_dones = prev_dones.at[eid].set(step_done)
    prev_values = prev_values.at[eid].set(step_value)

    next_carry = (prev_values, prev_dones, gae, prev_gaelam, trainable, seen)
    return next_carry, (gae, trainable)
@jax.jit
def compute_gae(
    env_ids: np.ndarray,
    rewards: np.ndarray,
    values: np.ndarray,
    dones: np.ndarray,
):
    """Generalised advantage estimation for an async rollout of shape (T, B).

    Scans compute_gae_once backwards over the T recv() rows.

    Returns:
        (advantages, returns, final_env_id_checked, final_env_ids), where
        returns = advantages + values, final_env_id_checked is the per-env "seen" state
        after the scan, and final_env_ids (T, B) is 1 for slots used for training and
        0 for each env's final (bootstrap-only) slot.
    """
    env_ids = jnp.asarray(env_ids)
    rewards = jnp.asarray(rewards)
    values = jnp.asarray(values)
    dones = jnp.asarray(dones)
    _, row_width = env_ids.shape

    initial_carry = (
        jnp.zeros(args.num_envs),  # lastvalues
        jnp.ones(args.num_envs),  # lastdones: 1 => nothing to bootstrap from initially
        jnp.zeros(row_width),  # advantages (placeholder)
        jnp.zeros(args.num_envs),  # lastgaelam
        jnp.zeros(row_width, jnp.int32),  # final_env_ids (placeholder)
        jnp.full(args.num_envs, -1, jnp.int32),  # final_env_id_checked: -1 = not seen yet
    )
    final_carry, (advantages, final_env_ids) = jax.lax.scan(
        compute_gae_once,
        initial_carry,
        (dones, values, env_ids, rewards),
        reverse=True,
    )
    final_env_id_checked = final_carry[5]
    return advantages, advantages + values, final_env_id_checked, final_env_ids
def ppo_loss(params, x, a, logp, mb_advantages, mb_returns):
    """Clipped-surrogate PPO loss for one minibatch.

    Args:
        params: AgentParams being optimised.
        x, a: minibatch observations and the actions taken.
        logp: log-probabilities of `a` under the rollout policy.
        mb_advantages, mb_returns: GAE advantages and value targets.

    Returns:
        (loss, (pg_loss, v_loss, entropy_loss, approx_kl)); approx_kl is wrapped in
        stop_gradient.
    """
    newlogprob, entropy, newvalue = get_action_and_value2(params, x, a)
    logratio = newlogprob - logp
    ratio = jnp.exp(logratio)
    # KL approximation from the probability ratio: mean((r - 1) - log r).
    approx_kl = ((ratio - 1) - logratio).mean()

    if args.norm_adv:
        centred = mb_advantages - mb_advantages.mean()
        mb_advantages = centred / (mb_advantages.std() + 1e-8)

    # Policy loss: pessimistic (max) of the unclipped and clipped surrogate.
    clipped_ratio = jnp.clip(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
    pg_loss = jnp.maximum(-mb_advantages * ratio, -mb_advantages * clipped_ratio).mean()

    # Value loss (unclipped).
    v_loss = 0.5 * ((newvalue - mb_returns) ** 2).mean()

    entropy_loss = entropy.mean()
    loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
    return loss, (pg_loss, v_loss, entropy_loss, jax.lax.stop_gradient(approx_kl))


# Returns ((loss, aux), grads w.r.t. `params`).
ppo_loss_grad_fn = jax.value_and_grad(ppo_loss, has_aux=True)
@jax.jit
def update_ppo(
    agent_state: TrainState,
    obs: list,
    dones: list,
    values: list,
    actions: list,
    logprobs: list,
    env_ids: list,
    rewards: list,
    key: jax.random.PRNGKey,
):
    """Compute advantages for one asynchronous rollout and run the PPO update epochs.

    Each rollout list holds T entries, one per envpool ``recv()``. Each entry covers the
    B environments reported in that batch, and ``env_ids`` records which env each column
    belongs to.

    Returns:
        (agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, advantages, returns,
        b_inds, final_env_ids, key):
          - the loss statistics have shape (update_epochs, num_minibatches);
          - advantages/returns/final_env_ids have shape (T, B);
          - b_inds holds the flat indices where final_env_ids == 1, padded with 0
            (jnp.nonzero's default fill) to length T * B.
    """
    # concatenating here instead of asarray + reshape, could speed up?
    b_obs = batch_obs(obs)
    dones = jnp.asarray(dones)
    values = jnp.asarray(values)
    actions = jnp.asarray(actions)
    logprobs = jnp.asarray(logprobs)
    env_ids = jnp.asarray(env_ids)
    rewards = jnp.asarray(rewards)

    # TODO: in an unlikely event, one of the envs might have not stepped at all, which may results in unexpected behavior
    # Use the rollout's own shape rather than the script-level `async_update` global.
    T, B = env_ids.shape
    index_ranges = jnp.arange(T * B, dtype=jnp.int32)
    next_index_ranges = jnp.zeros_like(index_ranges, dtype=jnp.int32)
    last_env_ids = jnp.zeros(args.num_envs, dtype=jnp.int32) - 1

    def f(carry, x):
        # Link every flat slot to the flat index of the *next* slot of the same env.
        # An env's final slot keeps 0. The write at index -1 for first-seen envs
        # rewrites that element's own value, so it has no effect.
        last_env_ids, next_index_ranges = carry
        env_id, index_range = x
        next_index_ranges = next_index_ranges.at[last_env_ids[env_id]].set(
            jnp.where(last_env_ids[env_id] != -1, index_range, next_index_ranges[last_env_ids[env_id]])
        )
        last_env_ids = last_env_ids.at[env_id].set(index_range)
        return (last_env_ids, next_index_ranges), None

    (last_env_ids, next_index_ranges), _ = jax.lax.scan(
        f,
        (last_env_ids, next_index_ranges),
        (env_ids.reshape(-1), index_ranges),
    )

    # rewards is off by one time step: the reward for a slot's action arrives with that
    # env's next recv(), so gather it from the linked next slot.
    rewards = rewards.reshape(-1)[next_index_ranges].reshape(T, B)
    advantages, returns, _, final_env_ids = compute_gae(env_ids, rewards, values, dones)
    b_inds = jnp.nonzero(final_env_ids.reshape(-1), size=T * B)[0]
    b_actions = actions.reshape(-1)
    b_logprobs = logprobs.reshape(-1)
    b_advantages = advantages.reshape(-1)
    b_returns = returns.reshape(-1)

    def update_minibatch(agent_state, minibatch):
        # One gradient step. Gradients are applied here and deliberately not emitted as a
        # scan output; stacking them would store a copy of every parameter per minibatch.
        mb_obs, mb_actions, mb_logprobs, mb_advantages, mb_returns = minibatch
        (loss, (pg_loss, v_loss, entropy_loss, approx_kl)), grads = ppo_loss_grad_fn(
            agent_state.params,
            mb_obs,
            mb_actions,
            mb_logprobs,
            mb_advantages,
            mb_returns,
        )
        agent_state = agent_state.apply_gradients(grads=grads)
        return agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl)

    def update_epoch(carry, _):
        agent_state, key = carry
        key, subkey = jax.random.split(key)

        # taken from: https://github.com/google/brax/blob/main/brax/training/agents/ppo/train.py
        def convert_data(x: jnp.ndarray):
            # The same subkey is used for every array so that (presumably) all are
            # permuted identically and rows stay aligned; then split into minibatches.
            x = jax.random.permutation(subkey, x)
            x = jnp.reshape(x, (args.num_minibatches, -1) + x.shape[1:])
            return x

        agent_state, stats = jax.lax.scan(
            update_minibatch,
            agent_state,
            tree.tree_map(convert_data, (b_obs, b_actions, b_logprobs, b_advantages, b_returns)),
        )
        return (agent_state, key), stats

    (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl) = jax.lax.scan(
        update_epoch, (agent_state, key), (), length=args.update_epochs
    )
    return agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, advantages, returns, b_inds, final_env_ids, key
# TRY NOT TO MODIFY: start the game
if __name__ == "__main__":
    # Training loop: collect an async rollout, update PPO, log; then optionally
    # save/evaluate/upload the final model.
    global_step = 0
    start_time = time.time()
    # Number of recv() batches needed for every env to report once.
    async_update = int(args.num_envs / args.async_batch_size)

    # put data in the last index
    # Per-env running episode statistics, indexed by info["env_id"].
    episode_returns = np.zeros((args.num_envs,), dtype=np.float32)
    returned_episode_returns = np.zeros((args.num_envs,), dtype=np.float32)
    episode_lengths = np.zeros((args.num_envs,), dtype=np.float32)
    returned_episode_lengths = np.zeros((args.num_envs,), dtype=np.float32)
    envs.async_reset()

    for update in range(1, args.num_updates + 2):
        update_time_start = time.time()
        obs = []
        dones = []
        actions = []
        logprobs = []
        values = []
        env_ids = []
        rewards = []
        truncations = []
        terminations = []
        env_recv_time = 0
        inference_time = 0
        storage_time = 0
        env_send_time = 0

        # NOTE: This is a major difference from the sync version:
        # at the end of the rollout phase, the sync version will have the next observation
        # ready for the value bootstrap, but the async version will not have it.
        # for this reason we do `num_steps + 1`` to get the extra states for value bootstrapping.
        # but note that the extra states are not used for the loss computation in the next iteration,
        # while the sync version will use the extra state for the loss computation.
        for step in range(
            async_update, (args.num_steps + 1) * async_update
        ):  # num_steps + 1 to get the states for value bootstrapping.
            env_recv_time_start = time.time()
            next_obs, next_reward, next_done, info = envs.recv()
            env_recv_time += time.time() - env_recv_time_start
            global_step += len(next_done)
            env_id = info["env_id"]

            inference_time_start = time.time()
            action, logprob, value, key = get_action_and_value(agent_state, next_obs, key)
            inference_time += time.time() - inference_time_start

            env_send_time_start = time.time()
            envs.send(np.array(action), env_id)
            env_send_time += time.time() - env_send_time_start

            storage_time_start = time.time()
            obs.append(next_obs)
            dones.append(next_done)
            values.append(value)
            actions.append(action)
            logprobs.append(logprob)
            env_ids.append(env_id)
            rewards.append(next_reward)
            truncations.append(info["TimeLimit.truncated"])
            terminations.append(info["terminated"])

            # Accumulate the raw (unclipped) info["reward"]; on episode end, publish the
            # totals to returned_* and reset the running counters.
            episode_returns[env_id] += info["reward"]
            returned_episode_returns[env_id] = np.where(
                info["terminated"] + info["TimeLimit.truncated"], episode_returns[env_id], returned_episode_returns[env_id]
            )
            episode_returns[env_id] *= (1 - info["terminated"]) * (1 - info["TimeLimit.truncated"])
            episode_lengths[env_id] += 1
            returned_episode_lengths[env_id] = np.where(
                info["terminated"] + info["TimeLimit.truncated"], episode_lengths[env_id], returned_episode_lengths[env_id]
            )
            episode_lengths[env_id] *= (1 - info["terminated"]) * (1 - info["TimeLimit.truncated"])
            storage_time += time.time() - storage_time_start

        avg_episodic_return = np.mean(returned_episode_returns)
        print(f"global_step={global_step}, avg_episodic_return={avg_episodic_return}")
        writer.add_scalar("charts/avg_episodic_return", avg_episodic_return, global_step)
        writer.add_scalar("charts/avg_episodic_length", np.mean(returned_episode_lengths), global_step)

        training_time_start = time.time()
        (
            agent_state,
            loss,
            pg_loss,
            v_loss,
            entropy_loss,
            approx_kl,
            advantages,
            returns,
            b_inds,
            final_env_ids,
            key,
        ) = update_ppo(
            agent_state,
            obs,
            dones,
            values,
            actions,
            logprobs,
            env_ids,
            rewards,
            key,
        )
        writer.add_scalar("stats/training_time", time.time() - training_time_start, global_step)
        writer.add_scalar("stats/truncations", np.sum(truncations), global_step)
        writer.add_scalar("stats/terminations", np.sum(terminations), global_step)

        # TRY NOT TO MODIFY: record rewards for plotting purposes
        # Loss stats are (update_epochs, num_minibatches); log the last minibatch of the last epoch.
        writer.add_scalar("charts/learning_rate", agent_state.opt_state[1].hyperparams["learning_rate"].item(), global_step)
        writer.add_scalar("losses/value_loss", v_loss[-1, -1].item(), global_step)
        writer.add_scalar("losses/policy_loss", pg_loss[-1, -1].item(), global_step)
        writer.add_scalar("losses/entropy", entropy_loss[-1, -1].item(), global_step)
        writer.add_scalar("losses/approx_kl", approx_kl[-1, -1].item(), global_step)
        writer.add_scalar("losses/loss", loss[-1, -1].item(), global_step)
        print("SPS:", int(global_step / (time.time() - start_time)))
        writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
        writer.add_scalar(
            "charts/SPS_update", int(args.num_envs * args.num_steps / (time.time() - update_time_start)), global_step
        )
        writer.add_scalar("stats/env_recv_time", env_recv_time, global_step)
        writer.add_scalar("stats/inference_time", inference_time, global_step)
        writer.add_scalar("stats/storage_time", storage_time, global_step)
        writer.add_scalar("stats/env_send_time", env_send_time, global_step)
        writer.add_scalar("stats/update_time", time.time() - update_time_start, global_step)

    if args.save_model:
        model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
        with open(model_path, "wb") as f:
            f.write(
                flax.serialization.to_bytes(
                    [
                        vars(args),
                        [
                            agent_state.params.network_params,
                            agent_state.params.actor_params,
                            agent_state.params.critic_params,
                        ],
                    ]
                )
            )
        print(f"model saved to {model_path}")
        from cleanrl_utils.evals.ppo_envpool_jax_eval import evaluate

        episodic_returns = evaluate(
            model_path,
            make_env,
            args.env_id,
            eval_episodes=10,
            run_name=f"{run_name}-eval",
            Model=(Network, Actor, Critic),
        )
        for idx, episodic_return in enumerate(episodic_returns):
            writer.add_scalar("eval/episodic_return", episodic_return, idx)

        # Uploading needs `episodic_returns`, which only exists once the model has been
        # saved and evaluated, so it lives inside the save_model branch.
        if args.upload_model:
            from cleanrl_utils.huggingface import push_to_hub

            repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}"
            repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name
            push_to_hub(args, episodic_returns, repo_id, "PPO", f"runs/{run_name}", f"videos/{run_name}-eval")

    envs.close()
    writer.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment