Skip to content

Instantly share code, notes, and snippets.

@sdpkjc
Created April 23, 2024 11:31
Show Gist options
  • Save sdpkjc/71cbba3bd439d754ef9c026c31658ecd to your computer and use it in GitHub Desktop.
Save sdpkjc/71cbba3bd439d754ef9c026c31658ecd to your computer and use it in GitHub Desktop.
CleanRL's PPO (Handle truncation properly)
import argparse
import copy
import os
import random
import time
from distutils.util import strtobool
from typing import Callable
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.normal import Normal
from torch.utils.tensorboard import SummaryWriter
def parse_args():
    """Parse the command-line options of the PPO training script.

    Options are read from ``sys.argv``. Boolean options accept either a bare
    flag (``--track``) or an explicit value (``--track false``).

    Returns:
        The parsed ``argparse.Namespace`` with two derived fields added:
        ``batch_size`` (``num_envs * num_steps``) and ``minibatch_size``
        (``batch_size // num_minibatches``).
    """
    # fmt: off
    # Default experiment name is this file's name without the ".py" extension.
    # NOTE: `str.rstrip(".py")` would strip any trailing '.', 'p' or 'y' characters
    # (e.g. "ppo_happy.py" -> "ppo_ha"), so the suffix is removed explicitly instead.
    default_exp_name = os.path.basename(__file__)
    if default_exp_name.endswith(".py"):
        default_exp_name = default_exp_name[:-len(".py")]

    parser = argparse.ArgumentParser()
    parser.add_argument("--exp-name", type=str, default=default_exp_name,
        help="the name of this experiment")
    parser.add_argument("--seed", type=int, default=1,
        help="seed of the experiment")
    parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="if toggled, `torch.backends.cudnn.deterministic=False`")
    parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="if toggled, cuda will be enabled by default")
    parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="if toggled, this experiment will be tracked with Weights and Biases")
    parser.add_argument("--wandb-project-name", type=str, default="cleanRL",
        help="the wandb's project name")
    parser.add_argument("--wandb-entity", type=str, default=None,
        help="the entity (team) of wandb's project")
    parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="whether to capture videos of the agent performances (check out `videos` folder)")
    parser.add_argument("--save-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="whether to save model into the `runs/{run_name}` folder")
    parser.add_argument("--upload-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="whether to upload the saved model to huggingface")
    parser.add_argument("--hf-entity", type=str, default="",
        help="the user or org name of the model repository from the Hugging Face Hub")

    # Algorithm specific arguments
    parser.add_argument("--env-id", type=str, default="HalfCheetah-v4",
        help="the id of the environment")
    parser.add_argument("--total-timesteps", type=int, default=1000000,
        help="total timesteps of the experiments")
    parser.add_argument("--learning-rate", type=float, default=3e-4,
        help="the learning rate of the optimizer")
    parser.add_argument("--num-envs", type=int, default=1,
        help="the number of parallel game environments")
    parser.add_argument("--num-steps", type=int, default=2048,
        help="the number of steps to run in each environment per policy rollout")
    parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="Toggle learning rate annealing for policy and value networks")
    parser.add_argument("--gamma", type=float, default=0.99,
        help="the discount factor gamma")
    parser.add_argument("--gae-lambda", type=float, default=0.95,
        help="the lambda for the general advantage estimation")
    parser.add_argument("--num-minibatches", type=int, default=32,
        help="the number of mini-batches")
    parser.add_argument("--update-epochs", type=int, default=10,
        help="the K epochs to update the policy")
    parser.add_argument("--norm-adv", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="Toggles advantages normalization")
    parser.add_argument("--clip-coef", type=float, default=0.2,
        help="the surrogate clipping coefficient")
    parser.add_argument("--clip-vloss", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="Toggles whether or not to use a clipped loss for the value function, as per the paper.")
    parser.add_argument("--ent-coef", type=float, default=0.0,
        help="coefficient of the entropy")
    parser.add_argument("--vf-coef", type=float, default=0.5,
        help="coefficient of the value function")
    parser.add_argument("--max-grad-norm", type=float, default=0.5,
        help="the maximum norm for the gradient clipping")
    parser.add_argument("--target-kl", type=float, default=None,
        help="the target KL divergence threshold")
    args = parser.parse_args()

    # Derived sizes: one rollout batch and the slice of it used per gradient step.
    args.batch_size = int(args.num_envs * args.num_steps)
    args.minibatch_size = int(args.batch_size // args.num_minibatches)
    # fmt: on
    return args
# https://github.com/Farama-Foundation/Gymnasium/blob/main/gymnasium/wrappers/normalize.py
class RunningMeanStd(nn.Module):
    """Running mean / variance / sample count kept as float64 module buffers.

    Because the statistics are registered buffers, they are included in the
    module's ``state_dict``.
    """

    def __init__(self, epsilon=1e-4, shape=()):
        """Start from mean 0, variance 1 and a sample count of ``epsilon``."""
        super().__init__()
        initial_buffers = {
            "mean": torch.zeros(shape, dtype=torch.float64),
            "var": torch.ones(shape, dtype=torch.float64),
            "count": torch.tensor(epsilon, dtype=torch.float64),
        }
        for buffer_name, initial_value in initial_buffers.items():
            self.register_buffer(buffer_name, initial_value)

    def update(self, x):
        """Fold the batch ``x`` (samples along dim 0) into the running statistics."""
        target_device = self.mean.device
        samples = torch.as_tensor(x, dtype=torch.float64).to(target_device)
        sample_mean = torch.mean(samples, dim=0).to(target_device)
        # Population (biased) variance of the batch.
        sample_var = torch.var(samples, dim=0, unbiased=False).to(target_device)
        sample_count = samples.shape[0]
        merged = update_mean_var_count_from_moments(
            self.mean, self.var, self.count, sample_mean, sample_var, sample_count
        )
        self.mean, self.var, self.count = merged
def update_mean_var_count_from_moments(mean, var, count, batch_mean, batch_var, batch_count):
    """Combine running moments with the moments of one new batch.

    Args:
        mean, var, count: current running mean, variance and sample count.
        batch_mean, batch_var, batch_count: the same quantities for the batch.

    Returns:
        Tuple ``(new_mean, new_var, new_count)`` describing both sets together.
    """
    total = count + batch_count
    shift = batch_mean - mean
    merged_mean = mean + shift * batch_count / total
    # Sum of squared deviations of both groups plus the term that accounts for
    # the distance between the two means.
    sum_sq = var * count + batch_var * batch_count + torch.square(shift) * count * batch_count / total
    return merged_mean, sum_sq / total, total
class NormalizeObservation(gym.Wrapper, gym.utils.RecordConstructorArgs):
    """Wrapper that standardizes observations with running mean/variance statistics.

    Two switches control the behavior:
      * ``freeze`` - when True the running statistics are no longer updated.
      * ``enable`` - when False observations are returned un-normalized.
    """

    def __init__(self, env: gym.Env, epsilon: float = 1e-8):
        """Wrap ``env``; ``epsilon`` is added to the variance before the square root."""
        gym.utils.RecordConstructorArgs.__init__(self, epsilon=epsilon)
        gym.Wrapper.__init__(self, env)

        # Fall back to single-env settings when the wrapped env exposes neither attribute.
        try:
            num_envs = self.get_wrapper_attr("num_envs")
            is_vector_env = self.get_wrapper_attr("is_vector_env")
        except AttributeError:
            num_envs, is_vector_env = 1, False
        self.num_envs = num_envs
        self.is_vector_env = is_vector_env

        # Statistics are tracked per observation element (batch axis excluded).
        stats_space = self.single_observation_space if self.is_vector_env else self.observation_space
        self.obs_rms = RunningMeanStd(shape=stats_space.shape)
        self.epsilon = epsilon
        self.enable = True
        self.freeze = False

    def step(self, action):
        """Step the wrapped env and normalize the returned observation."""
        obs, rews, terminateds, truncateds, infos = self.env.step(action)
        if not self.is_vector_env:
            # Add a batch axis for the statistics update, then strip it again.
            obs = self.normalize(np.array([obs]))[0]
        else:
            obs = self.normalize(obs)
        return obs, rews, terminateds, truncateds, infos

    def reset(self, **kwargs):
        """Reset the wrapped env and normalize the initial observation."""
        obs, info = self.env.reset(**kwargs)
        if not self.is_vector_env:
            return self.normalize(np.array([obs]))[0], info
        return self.normalize(obs), info

    def normalize(self, obs):
        """Update the statistics (unless frozen) and standardize ``obs`` (if enabled)."""
        if not self.freeze:
            self.obs_rms.update(obs)
        if not self.enable:
            return obs
        running_mean = self.obs_rms.mean.cpu().numpy()
        running_var = self.obs_rms.var.cpu().numpy()
        return (obs - running_mean) / np.sqrt(running_var + self.epsilon)
class NormalizeReward(gym.core.Wrapper, gym.utils.RecordConstructorArgs):
    """Wrapper that rescales rewards by the running std of the discounted return.

    Two switches control the behavior:
      * ``freeze`` - when True the running statistics are no longer updated.
      * ``enable`` - when False rewards are returned unscaled.
    """

    def __init__(
        self,
        env: gym.Env,
        gamma: float = 0.99,
        epsilon: float = 1e-8,
    ):
        """Wrap ``env``; ``gamma`` discounts the tracked return, ``epsilon`` guards the sqrt."""
        gym.utils.RecordConstructorArgs.__init__(self, gamma=gamma, epsilon=epsilon)
        gym.Wrapper.__init__(self, env)

        # Fall back to single-env settings when the wrapped env exposes neither attribute.
        try:
            num_envs = self.get_wrapper_attr("num_envs")
            is_vector_env = self.get_wrapper_attr("is_vector_env")
        except AttributeError:
            num_envs, is_vector_env = 1, False
        self.num_envs = num_envs
        self.is_vector_env = is_vector_env

        self.return_rms = RunningMeanStd(shape=())
        self.returns = np.zeros(self.num_envs)
        self.gamma = gamma
        self.epsilon = epsilon
        self.enable = True
        self.freeze = False

    def step(self, action):
        """Step the wrapped env, update the discounted return and scale the reward."""
        obs, rews, terminateds, truncateds, infos = self.env.step(action)
        single_env = not self.is_vector_env
        if single_env:
            rews = np.array([rews])
        # Only `terminateds` zeroes the carried return here; `truncateds` is not used.
        self.returns = self.returns * self.gamma * (1 - terminateds) + rews
        rews = self.normalize(rews)
        if single_env:
            rews = rews[0]
        return obs, rews, terminateds, truncateds, infos

    def reset(self, **kwargs):
        """Reset the wrapped env; the tracked discounted return is left untouched."""
        return self.env.reset(**kwargs)

    def normalize(self, rews):
        """Update the return statistics (unless frozen) and scale ``rews`` (if enabled)."""
        if not self.freeze:
            self.return_rms.update(self.returns)
        if not self.enable:
            return rews
        running_var = self.return_rms.var.cpu().numpy()
        return rews / np.sqrt(running_var + self.epsilon)

    def get_returns(self):
        """Return the current discounted-return accumulator."""
        return self.returns
def evaluate(
    model_path: str,
    make_env: Callable,
    env_id: str,
    eval_episodes: int,
    run_name: str,
    Model: torch.nn.Module,
    device: torch.device = torch.device("cpu"),
    capture_video: bool = True,
):
    """Load a saved agent and roll it out until ``eval_episodes`` episodes finish.

    Args:
        model_path: path of a ``state_dict`` written with ``torch.save``.
        make_env: env-thunk factory, called as
            ``make_env(env_id, 0, capture_video, run_name)`` and again with the
            loaded agent's ``obs_rms`` as an extra fifth argument.
        env_id: environment id forwarded to ``make_env``.
        eval_episodes: number of finished episodes to collect.
        run_name: run name forwarded to ``make_env``.
        Model: agent class; ``Model(envs)`` must expose ``obs_rms`` and
            ``get_action_and_value``.
        device: device the agent is moved to and loaded on.
        capture_video: forwarded to ``make_env``.

    Returns:
        ``(episodic_returns, episodic_lengths)``: two lists holding the
        ``info["episode"]["r"]`` / ``info["episode"]["l"]`` entries of the
        finished episodes.
    """
    # A throwaway env is built only so Model can be constructed from it;
    # close it right away instead of leaking it on every evaluation call.
    probe_envs = gym.vector.SyncVectorEnv([make_env(env_id, 0, capture_video, run_name)])
    try:
        agent = Model(probe_envs).to(device)
    finally:
        probe_envs.close()
    agent.load_state_dict(torch.load(model_path, map_location=device))
    agent.eval()

    # The real evaluation env receives the observation statistics stored in the checkpoint.
    envs = gym.vector.SyncVectorEnv([make_env(env_id, 0, capture_video, run_name, agent.obs_rms)])
    episodic_returns, episodic_lengths = [], []
    try:
        obs, _ = envs.reset()
        while len(episodic_returns) < eval_episodes:
            # Inference only: no autograd graph is needed for the rollout.
            with torch.no_grad():
                actions, _, _, _ = agent.get_action_and_value(torch.Tensor(obs).to(device))
            next_obs, _, _, _, infos = envs.step(actions.cpu().numpy())
            if "final_info" in infos:
                for info in infos["final_info"]:
                    # Entries can be None for sub-envs that did not finish (the training
                    # loop guards against this too); skip them and any without stats.
                    if info is None or "episode" not in info:
                        continue
                    print(f"eval_episode={len(episodic_returns)}, episodic_return={info['episode']['r']}")
                    episodic_returns += [info["episode"]["r"]]
                    episodic_lengths += [info["episode"]["l"]]
            obs = next_obs
    finally:
        # Release the env even if the rollout raises.
        envs.close()
    return episodic_returns, episodic_lengths
def make_env(env_id, idx, capture_video, run_name, gamma):
    """Return a thunk that builds one fully wrapped training environment.

    Args:
        env_id: id passed to ``gym.make``.
        idx: index of this env; video is recorded only for index 0.
        capture_video: create the env with ``render_mode="rgb_array"`` and,
            for ``idx == 0``, record videos under ``videos/{run_name}``.
        run_name: name used for the video folder.
        gamma: discount factor handed to ``NormalizeReward``.
    """

    def thunk():
        make_kwargs = {"render_mode": "rgb_array"} if capture_video else {}
        env = gym.make(env_id, **make_kwargs)
        env = gym.wrappers.FlattenObservation(env)  # deal with dm_control's Dict observation space
        env = gym.wrappers.RecordEpisodeStatistics(env)
        if capture_video and idx == 0:
            env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        env = gym.wrappers.ClipAction(env)
        # Normalize then clip observations, and likewise for rewards.
        env = NormalizeObservation(env)
        env = gym.wrappers.TransformObservation(env, lambda obs: np.clip(obs, -10, 10))
        env = NormalizeReward(env, gamma=gamma)
        env = gym.wrappers.TransformReward(env, lambda reward: np.clip(reward, -10, 10))
        return env

    return thunk
def make_eval_env(env_id, idx, capture_video, run_name, obs_rms=None):
    """Return a thunk that builds one evaluation environment.

    Unlike the training env, no reward normalization or reward clipping is
    applied. When ``obs_rms`` is given, the observation normalizer gets a deep
    copy of it and stops updating its statistics.

    Args:
        env_id: id passed to ``gym.make``.
        idx: index of this env; video is recorded only for index 0.
        capture_video: create the env with ``render_mode="rgb_array"`` and,
            for ``idx == 0``, record videos under ``videos/{run_name}``.
        run_name: name used for the video folder.
        obs_rms: optional observation statistics to install in the normalizer.
    """

    def thunk():
        make_kwargs = {"render_mode": "rgb_array"} if capture_video else {}
        env = gym.make(env_id, **make_kwargs)
        env = gym.wrappers.FlattenObservation(env)  # deal with dm_control's Dict observation space
        env = gym.wrappers.RecordEpisodeStatistics(env)
        if capture_video and idx == 0:
            env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        env = gym.wrappers.ClipAction(env)
        env = NormalizeObservation(env)
        if obs_rms is not None:
            # Use the provided statistics as-is; do not let evaluation data change them.
            env.obs_rms = copy.deepcopy(obs_rms)
            env.freeze = True
        env = gym.wrappers.TransformObservation(env, lambda obs: np.clip(obs, -10, 10))
        return env

    return thunk
def get_rms(env):
    """Return deep copies of the normalization statistics found in a wrapper stack.

    The stack is walked from ``env`` inwards through ``.env`` attributes.

    Returns:
        ``(obs_rms, return_rms)`` copied from the first ``NormalizeObservation``
        and the first ``NormalizeReward`` wrapper encountered.

    Raises:
        RuntimeError: if either wrapper is missing (the observation wrapper is
            looked up first).
    """

    def copy_stats(wrapper_cls, stats_attr, missing_message):
        # Walk inwards while the current layer still wraps something.
        layer = env
        while hasattr(layer, "env"):
            if isinstance(layer, wrapper_cls):
                return copy.deepcopy(getattr(layer, stats_attr))
            layer = layer.env
        raise RuntimeError(missing_message)

    obs_rms = copy_stats(NormalizeObservation, "obs_rms", "can't find NormalizeObservation")
    return_rms = copy_stats(NormalizeReward, "return_rms", "can't find NormalizeReward")
    return obs_rms, return_rms
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialize ``layer`` in place and return it.

    The weight gets an orthogonal initialization with gain ``std``; every bias
    entry is set to ``bias_const``.
    """
    nn.init.orthogonal_(layer.weight, gain=std)
    nn.init.constant_(layer.bias, val=bias_const)
    return layer
class Agent(nn.Module):
    """Actor-critic with separate MLPs and a state-independent log-std for the policy.

    The module also carries an ``obs_rms`` submodule so observation statistics
    are stored alongside the network weights.
    """

    def __init__(self, envs):
        """Size the networks from ``envs.single_observation_space`` / ``single_action_space``."""
        super().__init__()
        obs_dim = np.array(envs.single_observation_space.shape).prod()
        act_dim = np.prod(envs.single_action_space.shape)

        def build_mlp(out_features, out_std):
            # Two hidden Tanh layers of width 64; only the output layer's init gain varies.
            return nn.Sequential(
                layer_init(nn.Linear(obs_dim, 64)),
                nn.Tanh(),
                layer_init(nn.Linear(64, 64)),
                nn.Tanh(),
                layer_init(nn.Linear(64, out_features), std=out_std),
            )

        # Critic is built before the actor; keep this order so RNG use stays the same.
        self.critic = build_mlp(1, 1.0)
        self.actor_mean = build_mlp(act_dim, 0.01)
        self.actor_logstd = nn.Parameter(torch.zeros(1, act_dim))
        self.obs_rms = RunningMeanStd(shape=envs.single_observation_space.shape)

    def get_value(self, x):
        """Return the critic's output for observations ``x``."""
        return self.critic(x)

    def get_action_and_value(self, x, action=None):
        """Evaluate the policy and the critic on observations ``x``.

        When ``action`` is None an action is sampled; otherwise the given
        action is scored.

        Returns:
            ``(action, log_prob, entropy, value)`` where ``log_prob`` and
            ``entropy`` are summed over dim 1.
        """
        mean = self.actor_mean(x)
        std = torch.exp(self.actor_logstd.expand_as(mean))
        dist = Normal(mean, std)
        if action is None:
            action = dist.sample()
        log_prob = dist.log_prob(action).sum(1)
        entropy = dist.entropy().sum(1)
        return action, log_prob, entropy, self.critic(x)
if __name__ == "__main__":
    # Script entry point: PPO training on a continuous-action environment.
    # Besides the usual rollout / GAE / clipped-surrogate update cycle, the loop
    #   * adds gamma * V(final observation) to the reward of truncated (not terminated) steps, and
    #   * periodically checkpoints the agent and evaluates it with `evaluate`.
    args = parse_args()
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    if args.track:
        import wandb

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )
    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    # env setup
    envs = gym.vector.SyncVectorEnv(
        [make_env(args.env_id, i, args.capture_video, run_name, args.gamma) for i in range(args.num_envs)]
    )
    assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported"

    agent = Agent(envs).to(device)
    optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)

    # ALGO Logic: Storage setup (one slot per rollout step and env)
    obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
    actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
    logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device)
    rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
    dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
    values = torch.zeros((args.num_steps, args.num_envs)).to(device)

    # TRY NOT TO MODIFY: start the game
    global_step = 0
    start_time = time.time()
    next_obs, _ = envs.reset(seed=args.seed)
    next_obs = torch.Tensor(next_obs).to(device)
    next_done = torch.zeros(args.num_envs).to(device)
    num_updates = args.total_timesteps // args.batch_size

    # Checkpoint/evaluation cadence in environment steps. global_step grows by num_envs per
    # step, so the interval is 5000 rounded down to a multiple of num_envs. The max(..., 1)
    # keeps it non-zero (no ZeroDivisionError in the modulo below) when num_envs > 5000.
    eval_interval = max(5000 // args.num_envs, 1) * args.num_envs

    for update in range(1, num_updates + 1):
        # Annealing the rate if instructed to do so.
        if args.anneal_lr:
            frac = 1.0 - (update - 1.0) / num_updates
            lrnow = frac * args.learning_rate
            optimizer.param_groups[0]["lr"] = lrnow

        for step in range(0, args.num_steps):
            global_step += args.num_envs
            obs[step] = next_obs
            dones[step] = next_done

            # ALGO LOGIC: action logic
            with torch.no_grad():
                action, logprob, _, value = agent.get_action_and_value(next_obs)
                values[step] = value.flatten()
            actions[step] = action
            logprobs[step] = logprob

            # TRY NOT TO MODIFY: execute the game and log data.
            next_obs, reward, terminations, truncations, infos = envs.step(action.cpu().numpy())
            done = np.logical_or(terminations, truncations)
            rewards[step] = torch.tensor(reward).to(device).view(-1)
            next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device)

            # https://github.com/DLR-RM/stable-baselines3/pull/658
            # A truncated episode did not really end: fold the discounted value of its final
            # observation into the reward so the return estimate still bootstraps from it.
            for idx, trunc in enumerate(truncations):
                if trunc and not terminations[idx]:
                    real_next_obs = infos["final_observation"][idx]
                    with torch.no_grad():
                        terminal_value = agent.get_value(torch.Tensor(real_next_obs).to(device)).reshape(1, -1)[0][0]
                    rewards[step][idx] += args.gamma * terminal_value

            if global_step % eval_interval == 0:
                # Store the first sub-env's observation statistics in the agent so they are
                # part of the checkpoint; get_rms already returns deep copies.
                agent.obs_rms = get_rms(envs.envs[0])[0]
                model_path = f"runs/{run_name}/{args.exp_name}-{global_step}.cleanrl_model"
                torch.save(agent.state_dict(), model_path)
                print(f"model saved to {model_path}")
                episodic_returns, episodic_lengths = evaluate(
                    model_path,
                    make_eval_env,
                    args.env_id,
                    eval_episodes=3,
                    run_name=f"{run_name}-eval",
                    Model=Agent,
                    device=device,
                    capture_video=False,
                )
                print(episodic_returns, episodic_lengths)
                writer.add_scalar("charts/eval/episodic_return", np.mean(episodic_returns), global_step)
                writer.add_scalar("charts/eval/episodic_length", np.mean(episodic_lengths), global_step)

            # Only print when at least 1 env is done
            if "final_info" not in infos:
                continue

            for info in infos["final_info"]:
                # Skip the envs that are not done
                if info is None:
                    continue
                print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
                writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
                writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)

        # bootstrap value if not done
        with torch.no_grad():
            next_value = agent.get_value(next_obs).reshape(1, -1)
            advantages = torch.zeros_like(rewards).to(device)
            lastgaelam = 0
            # Generalized advantage estimation, computed backwards over the rollout.
            for t in reversed(range(args.num_steps)):
                if t == args.num_steps - 1:
                    nextnonterminal = 1.0 - next_done
                    nextvalues = next_value
                else:
                    nextnonterminal = 1.0 - dones[t + 1]
                    nextvalues = values[t + 1]
                delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
                advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
            returns = advantages + values

        # flatten the batch
        b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
        b_logprobs = logprobs.reshape(-1)
        b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
        b_advantages = advantages.reshape(-1)
        b_returns = returns.reshape(-1)
        b_values = values.reshape(-1)

        # Optimizing the policy and value network
        b_inds = np.arange(args.batch_size)
        clipfracs = []
        for epoch in range(args.update_epochs):
            np.random.shuffle(b_inds)
            for start in range(0, args.batch_size, args.minibatch_size):
                end = start + args.minibatch_size
                mb_inds = b_inds[start:end]

                _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions[mb_inds])
                logratio = newlogprob - b_logprobs[mb_inds]
                ratio = logratio.exp()

                with torch.no_grad():
                    # calculate approx_kl http://joschu.net/blog/kl-approx.html
                    old_approx_kl = (-logratio).mean()
                    approx_kl = ((ratio - 1) - logratio).mean()
                    clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()]

                mb_advantages = b_advantages[mb_inds]
                if args.norm_adv:
                    mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

                # Policy loss
                pg_loss1 = -mb_advantages * ratio
                pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
                pg_loss = torch.max(pg_loss1, pg_loss2).mean()

                # Value loss
                newvalue = newvalue.view(-1)
                if args.clip_vloss:
                    v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
                    v_clipped = b_values[mb_inds] + torch.clamp(
                        newvalue - b_values[mb_inds],
                        -args.clip_coef,
                        args.clip_coef,
                    )
                    v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
                    v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                    v_loss = 0.5 * v_loss_max.mean()
                else:
                    v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()

                entropy_loss = entropy.mean()
                loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
                optimizer.step()

            # Early-stop the remaining epochs once the policy moved too far from the old one.
            if args.target_kl is not None:
                if approx_kl > args.target_kl:
                    break

        y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
        var_y = np.var(y_true)
        explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y

        # TRY NOT TO MODIFY: record rewards for plotting purposes
        writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
        writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
        writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
        writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
        writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step)
        writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
        writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)
        writer.add_scalar("losses/explained_variance", explained_var, global_step)
        print("SPS:", int(global_step / (time.time() - start_time)))
        writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

    if args.save_model:
        # Final checkpoint (with the first sub-env's observation statistics) and evaluation.
        agent.obs_rms = get_rms(envs.envs[0])[0]
        model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
        torch.save(agent.state_dict(), model_path)
        print(f"model saved to {model_path}")
        episodic_returns, episodic_lengths = evaluate(
            model_path,
            make_eval_env,
            args.env_id,
            eval_episodes=10,
            run_name=f"{run_name}-eval",
            Model=Agent,
            device=device,
        )
        for idx, episodic_return in enumerate(episodic_returns):
            writer.add_scalar("eval/episodic_return", episodic_return, idx)

        if args.upload_model:
            from cleanrl_utils.huggingface import push_to_hub

            repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}"
            repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name
            push_to_hub(args, episodic_returns, repo_id, "PPO", f"runs/{run_name}", f"videos/{run_name}-eval")

    envs.close()
    writer.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment