@aadharna
Last active June 30, 2023 15:30
cleanRL

For some reason I'm getting the following error:

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

The error is raised in the advantage calculation, which needs per-agent masks because this is a multi-agent setting.

However, the error only started appearing after I tried to fix a different problem: the network would sometimes output NaN values, which then broke the Categorical distribution used for action sampling.
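
For reference, here is a minimal sketch of one way to guard the sampling step while the source of the NaNs is tracked down; it mirrors the commented-out torch.where guard in get_action_and_value near the bottom of this gist, and the helper name sample_action is only illustrative, not part of the gist:

import torch
from torch.distributions.categorical import Categorical

def sample_action(logits: torch.Tensor):
    # replace any NaN logits with a very negative number so those actions
    # get ~zero probability instead of crashing Categorical
    if torch.isnan(logits).any():
        logits = torch.nan_to_num(logits, nan=-1e8)
    probs = Categorical(logits=logits)
    action = probs.sample()
    return action, probs.log_prob(action)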

I know the file is long, sorry.

Note that the number of update epochs and the number of minibatches are both set to 1, so there should only be one backward() call per data collection; backward shouldn't be able to run twice on the same batch.
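
For context, here is a minimal, self-contained sketch (not this gist's code) of how the quoted RuntimeError typically arises: tensors saved during the rollout still carry their autograd graph, and a second backward pass through any slice of them hits buffers that the first backward already freed. This looks consistent with how rollout() below stores values and logprobs without detaching them.

import torch

net = torch.nn.Linear(4, 1)

# "rollout": outputs are produced with autograd enabled and concatenated,
# much like RolloutData.append_values / combine below
stored_values = torch.cat([net(torch.randn(1, 4)) for _ in range(8)])

# "update": two minibatches sliced from the same stored tensor
for start in range(0, 8, 4):
    mb_loss = (stored_values[start:start + 4] ** 2).mean()
    # the first backward frees the graph shared (via torch.cat) by every row;
    # the second raises "Trying to backward through the graph a second time"
    mb_loss.backward()

The usual cleanRL-style remedy is to collect the rollout under torch.no_grad() (or to detach values, logprobs, and logits before storing them) and recompute them from the stored observations inside update_model, which usually removes the need for retain_graph=True as well.

# ---- main training script ----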

WANDB_ENV_VAR = "WANDB_API_KEY"
# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_pettingzoo_ma_ataripy
import os
import time
import random
import pickle
import argparse
import itertools
from uuid import uuid4
from dataclasses import dataclass
from distutils.util import strtobool
from collections import defaultdict, OrderedDict
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
torch.autograd.set_detect_anomaly(True)
from rating import OpenSkillRating
try:
import pyspiel
from open_spiel.python.rl_environment import Environment
except ImportError:
raise Exception("Please install open_spiel python package.")
import pettingzoo
import supersuit as ss
from c4.ppo import update_model, eval_agent_against_archive, generate_data
from c4.c4_spiel_wrapper import OpenSpielCompatibleSkipHandleEnv
def parse_args():
# fmt: off
parser = argparse.ArgumentParser()
parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"),
help="the name of this experiment")
parser.add_argument("--seed", type=int, default=654,
help="seed of the experiment")
parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="if toggled, `torch.backends.cudnn.deterministic=False`")
parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="if toggled, cuda will be enabled by default")
# Algorithm specific arguments
parser.add_argument("--env-id", type=str, default="c4",
help="the id of the environment")
parser.add_argument("--total-timesteps", type=int, default=500000,
help="total timesteps of the experiments")
parser.add_argument("--learning-rate", type=float, default=5e-6,
help="the learning rate of the optimizer")
parser.add_argument("--num-envs", type=int, default=1,
help="the number of parallel game environments")
parser.add_argument("--num-steps", type=int, default=64,
help="the number of steps to run in each environment per policy rollout")
parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="Toggle learning rate annealing for policy and value networks")
parser.add_argument("--gamma", type=float, default=0.99,
help="the discount factor gamma")
parser.add_argument("--gae-lambda", type=float, default=0.95,
help="the lambda for the general advantage estimation")
parser.add_argument("--num-minibatches", type=int, default=1,
help="the number of mini-batches")
parser.add_argument("--update-epochs", type=int, default=1,
help="the K epochs to update the policy")
parser.add_argument("--norm-adv", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="Toggles advantages normalization")
parser.add_argument("--clip-coef", type=float, default=0.5,
help="the surrogate clipping coefficient")
parser.add_argument("--clip-vloss", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="Toggles whether or not to use a clipped loss for the value function, as per the paper.")
parser.add_argument("--ent-coef", type=float, default=0.01,
help="coefficient of the entropy")
parser.add_argument("--vf-coef", type=float, default=0.01,
help="coefficient of the value function")
parser.add_argument("--max-grad-norm", type=float, default=0.05,
help="the maximum norm for the gradient clipping")
parser.add_argument("--target-kl", type=float, default=None,
help="the target KL divergence threshold")
# novelty search specific arguments
parser.add_argument("--n_eval_matches", type=int, default=1, help="number of eval matches against each opponent")
# ranking specific arguments
parser.add_argument("--mu", type=int, default=1000, help="mu for elo ranking")
parser.add_argument("--anchor_mu", type=int, default=1500, help="anchor mu for elo ranking")
parser.add_argument("--sigma", type=float, default=100/3, help="sigma for elo ranking")
args = parser.parse_args()
args.batch_size = int(args.num_steps * 1) # 15 is game length for i-RPSW
args.minibatch_size = int(args.batch_size // args.num_minibatches)
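# note: update_model below sizes its batch from the data actually collected (b_obs.shape[0]);
# generate_data runs args.num_steps full games, so the collected batch is usually much larger
# than num_steps, and a minibatch_size of num_steps then yields several minibatches per epoch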
# fmt: on
return args
class InfToIntWrapper(pettingzoo.AECEnv):
def __init__(self, env):
super().__init__()
self.env = env
self.action_space = env.action_space
self.observation_space = env.observation_space
self.agent_iter = env.agent_iter
self.metadata = env.metadata
self.render = env.render
def step(self, action):
return self.env.step(action)
def reset(self, **kwargs):
return self.env.reset(**kwargs)
def last(self, observe: bool = True):
obs, reward, done, trunc, info = self.env.last()
obs[obs == np.inf] = 1
return obs, reward, done, trunc, info
def _get_action_name_from_id(action_id):
# map a Connect Four column index (0-6) to its action name, e.g. 3 -> "c3"
return f"c{action_id}" if 0 <= action_id <= 6 else None
def anneal_lr(optimizer, update, num_updates, learning_rate):
"""
Anneals the learning rate linearly from args.learning_rate down to 0.
"""
frac = 1.0 - (update - 1.0) / num_updates
lrnow = frac * learning_rate
optimizer.param_groups[0]["lr"] = lrnow
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
torch.nn.init.orthogonal_(layer.weight, std)
torch.nn.init.constant_(layer.bias, bias_const)
return layer
class Agent(nn.Module):
def __init__(self, envs):
super().__init__()
self.critic = nn.Sequential(
layer_init(nn.Linear(envs.single_observation_space.shape[0], 512)),
nn.ReLU(),
layer_init(nn.Linear(512, 256)),
nn.ReLU(),
layer_init(nn.Linear(256, 1), std=1.0),
)
self.actor = nn.Sequential(
layer_init(nn.Linear(envs.single_observation_space.shape[0], 512)),
nn.ReLU(),
layer_init(nn.Linear(512, 256)),
nn.ReLU(),
layer_init(nn.Linear(256, envs.single_action_space.n), std=0.01),
)
def get_value(self, x):
return self.critic(x)
def get_action_and_value(self, x, action=None):
if x.ndim == 1:
x = x.unsqueeze(0)
logits = self.actor(x)
probs = torch.distributions.Categorical(logits=logits)
if action is None:
action = probs.sample()
return action, probs.log_prob(action), probs.entropy(), self.critic(x), logits.cpu().detach()
if __name__ == "__main__":
import wandb
args = parse_args()
run_name = f"{args.exp_name}_{args.seed}"
run = wandb.init(project="mapo",
entity="aadharna",
# sync_tensorboard=True,
config=vars(args),
name=run_name,
# monitor_gym=False,
save_code=True,
)
time.sleep(4)
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = args.torch_deterministic
# device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
device = torch.device('cpu')
env = pyspiel.load_game('connect_four')
env = OpenSpielCompatibleSkipHandleEnv(env)
env = ss.flatten_v0(env)
env = ss.agent_indicator_v0(env, type_only=False)
envs = InfToIntWrapper(env)
envs.single_action_space = env.action_space('player_0')
envs.single_observation_space = env.observation_space('player_0')
envs.reset()
agent = Agent(envs).to(device)
greedy_agent = Agent(envs).to(device)
greedy_agent.load_state_dict(agent.state_dict())
random_agent = Agent(envs).to(device)
random_agent.load_state_dict(agent.state_dict())
agent_ids = ['main', 'main_v1']
novelty_policy_map = {"main": agent.state_dict(),
"main_v1": random_agent.state_dict()}
rating = OpenSkillRating(args.mu, args.anchor_mu, args.sigma)
rating.add_policy('main')
rating.add_policy('main_v1')
rating.set_anchor(name='main_v1')
n_opponents = 1
greedy_n_opponents = 1
novelty_archive = OrderedDict()
main_states_by_iter = OrderedDict()
optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)
greedy_optimizer = optim.Adam(greedy_agent.parameters(), lr=args.learning_rate, eps=1e-5)
win_rate = 0
greedy_win_rate = 0
# TRY NOT TO MODIFY: start the game
global_step = 0
start_time = time.time()
envs.reset()
observation, _, _, _, _ = env.last()
next_obs = torch.from_numpy(observation).unsqueeze(0).float()
next_done = torch.zeros(args.num_envs).to(device)
num_updates = args.total_timesteps // args.batch_size
for update in range(1, num_updates + 1):
# Annealing the rate if instructed to do so.
if args.anneal_lr:
anneal_lr(optimizer=optimizer, update=update, num_updates=num_updates,
learning_rate=args.learning_rate)
(main_batch, opponent_batch, global_step, next_obs,
next_done, batch_reward_main, batch_reward_opp, batch_opponents_scores) = generate_data(
global_step, agent, random_agent, n_opponents, novelty_policy_map, args.num_steps, envs, device)
storage = {'main': main_batch, 'opponent': opponent_batch,
'next_obs': next_obs, 'next_dones': next_done}
save_policy_trigger = False
novelty_reward_value = 0
win_rate, action_stats = eval_agent_against_archive(agent, "main", random_agent, novelty_policy_map, rating,
args.n_eval_matches, envs, device)
print(f"Iter={update}/{num_updates}: win-rate={round(win_rate, 4)}")
loss, entropy_loss, pg_loss, v_loss, explained_var, approx_kl, meanclipfracs, old_approx_kl = update_model(
agent, "main", optimizer, envs, storage, args.num_steps, n_opt_steps=args.update_epochs,
minibatch_size=args.minibatch_size, gamma=args.gamma, gae_lambda=args.gae_lambda,
clip_coef=args.clip_coef, norm_adv=args.norm_adv, clip_vloss=args.clip_vloss,
max_grad_norm=args.max_grad_norm, ent_coef=args.ent_coef, vf_coef=args.vf_coef, target_kl=args.target_kl,
device=device
)
#### save data to wandb ####
# TRY NOT TO MODIFY: record rewards for plotting purposes
wandb.log({"charts/learning_rate": optimizer.param_groups[0]['lr']}, step=update)
wandb.log({"losses/value_loss": v_loss}, step=update)
wandb.log({"losses/policy_loss": pg_loss}, update)
wandb.log({"losses/entropy": entropy_loss}, update)
wandb.log({"losses/old_approx_kl": old_approx_kl}, update)
wandb.log({"losses/approx_kl": approx_kl}, update)
wandb.log({"losses/clipfrac": meanclipfracs}, update)
wandb.log({"losses/explained_variance": explained_var}, update)
# custom saves
wandb.log({"charts/puck/winrate": win_rate}, update)
wandb.log({"charts/puck/league_size": len(novelty_policy_map)}, update)
wandb.log({"charts/puck/novelty_score": novelty_reward_value}, update)
wandb.log({"charts/puck/mc_threshold": args.mc_threshold}, update)
wandb.log({"charts/puck/novelty_threshold": args.novelty_threshold}, update)
# save agents reward / std
wandb.log({"charts/puck/reward/main_mean": np.mean(batch_reward_main)}, update)
wandb.log({"charts/puck/reward/main_std": np.std(batch_reward_main)}, update)
wandb.log({"charts/puck/reward/main_max": np.max(batch_reward_main)}, update)
wandb.log({"charts/puck/reward/main_min": np.min(batch_reward_main)}, update)
wandb.log({"charts/puck/reward/oppo_mean": np.mean(batch_reward_opp)}, update)
wandb.log({"charts/puck/reward/oppo_std": np.std(batch_reward_opp)}, update)
# save probabilities for each action
for action_stat, prob in action_stats.items():
if len(prob) > 1:
continue
else:
prob = prob[0]
wandb.log({f"charts/puck/action_{action_stat}": prob}, step=update)
# save elo scores for all the agents
r = rating.ratings['main']
wandb.log({f"charts/elo/puck_v_novelty_archive": int(r.mu)}, step=update)
# save policies / archives
torch.save(agent.state_dict(), os.path.join(wandb.run.dir, "main_policy.pt"))
with open(os.path.join(wandb.run.dir, "archives.pkl"), "wb") as f:
pickle.dump({
"novelty_archive": novelty_archive,
"novelty_policy_map": novelty_policy_map,
}, f)
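# ---- c4/c4_spiel_wrapper.py (file boundary inferred from the `from c4.c4_spiel_wrapper import ...` line above) ----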
from typing import Optional
import numpy as np
from gymnasium.spaces import Box, Discrete
from shimmy.openspiel_compatibility import OpenSpielCompatibilityV0
try:
import pyspiel
from open_spiel.python.rl_environment import Environment
except ImportError:
raise Exception("Please install open_spiel python package.")
class OpenSpielCompatibleSkipHandleEnv(OpenSpielCompatibilityV0):
def __init__(self, env):
super().__init__(env)
def step(self, action):
try:
super().step(action)
except pyspiel.SpielError as e:
player_id = self.game_state.current_player()
if not self._end_routine() and player_id != -4:
random_action = np.random.choice(self.game_state.legal_actions(player_id))
self.game_state.apply_action(random_action)
else:
random_action = None
self.game_state.apply_action(random_action)
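# note: random_action is None in this branch, so apply_action(None) will raise;
# presumably the intent is to skip applying an action once the game can no longer accept one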
self._execute_chance_node()
self._update_action_masks()
self._update_observations()
self._update_rewards()
self._update_termination_truncation()
# pick the next agent
self._choose_next_agent()
# accumulate the rewards
self._accumulate_rewards()
class DSOpenSpielEnv(OpenSpielCompatibilityV0):
"""adds a dominant strategy to C4.
If the agent plays token in column 1, then 2, then 7, then the game ends and the agent wins"""
def __init__(self, env):
super().__init__(env)
self.dominant_strategy = [1, 2, 6]
self.dominant_strategy_index = 0
self.action_memory = []
def reset(self, *, seed=None, options=None):
super().reset(seed=seed, options=options)
self.dominant_strategy_index = 0
self.action_memory = []
return
def step(self, action):
# call super step
check_dom = False
if self.agent_selection == 'agent_0':
self.action_memory.append(action['agent_0'])
if len(self.action_memory) > 3:
self.action_memory.pop(0)
check_dom = True
super().step(action)
# check to see if the dominant strategy has been triggered by the 'main' agent
if check_dom:
if self.action_memory == self.dominant_strategy:
self.rewards = {a: r for a, r in zip(self.agents, self.game_state.rewards())}
self.rewards[self.agents[0]] = 1
self.rewards[self.agents[1]] = -1
self.terminations = {a: True for a in self.agents}
self.truncations = {a: True for a in self.agents}
# if __name__ == '__main__':
# import pyspiel
# from c4.c4_spiel_wrapper import OpenSpielEnv
# c4 = pyspiel.load_game('connect_four')
# env = OpenSpielEnv(c4)
# obs, _ = env.reset()
# print(obs)
# print(env.turn())
#
# while env.turn() != 'over':
# obs, rewards, dones, _, infos = env.step({env.turn(): 0})
#
# print(obs)
# print(rewards)
# print(env.turn())
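# ---- c4/ppo.py (file boundary inferred from the `from c4.ppo import ...` line above) ----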
import numpy as np
import torch
import torch.nn as nn
from typing import Optional
from collections import defaultdict
from .utils import RolloutData, rollout, _get_action_name_from_id
def anneal_lr(optimizer, update, num_updates, learning_rate):
"""
Anneals the learning rate linearly from args.learning_rate down to 0.
"""
frac = 1.0 - (update - 1.0) / num_updates
lrnow = frac * learning_rate
optimizer.param_groups[0]["lr"] = lrnow
def generate_data(global_step, learner, opponent, n_opponents,
archive, n_opponents_in_batch, env, device):
"""
Generates data for the learner to train on.
"""
main_batch = RolloutData(episode=[],
agent_id=[],
obs=torch.Tensor([]),
actions=torch.Tensor([]),
logprobs=torch.Tensor([]),
rewards=torch.Tensor([]),
dones=torch.Tensor([]),
next_obs=torch.Tensor([]),
values=torch.Tensor([]),
logits=torch.Tensor([])).to(device)
opponent_batch = RolloutData(episode=[],
agent_id=[],
obs=torch.Tensor([]),
actions=torch.Tensor([]),
logprobs=torch.Tensor([]),
rewards=torch.Tensor([]),
dones=torch.Tensor([]),
next_obs=torch.Tensor([]),
values=torch.Tensor([]),
logits=torch.Tensor([])).to(device)
batch_opponents_scores = defaultdict(list)
batch_reward_main = []
batch_reward_opp = []
for step in range(0, n_opponents_in_batch):
opponent_id = np.random.choice(range(1, n_opponents + 1))
opponent_weights = archive[f"main_v{opponent_id}"]
opponent.load_state_dict(opponent_weights)
# for j in range(0, args.num_steps // env_config.get("max_steps", 15)):
rollout_results = rollout(env, learner, opponent, f"main_v{opponent_id}", device)
# print("in generate data: ", rollout_results.values._version)
rollout_results.to(device)
global_step += rollout_results.rewards.shape[0]
main_batch.combine(rollout_results)
next_done = rollout_results.dones[-1]
next_obs = rollout_results.next_obs[-1]
opponent_batch.combine(rollout_results)
batch_opponents_scores[f"main_v{opponent_id}"].append(rollout_results.rewards.sum())
batch_reward_main.append(rollout_results.rewards.sum().item())
batch_reward_opp.append(batch_opponents_scores[f"main_v{opponent_id}"][-1].item())
return (main_batch, opponent_batch, global_step, next_obs, next_done,
batch_reward_main, batch_reward_opp, batch_opponents_scores)
def eval_agent_against_archive(learner, learner_id, opponent, archive, rating_object, n_eval_matches, env, device,
skip=None):
learner.eval()
scores = defaultdict(list)
action_stats = defaultdict(list)
if skip is None:
skip = ['main', 'greedy_main']
for opponent_id, opponent_policy in archive.items():
if opponent_id in skip:
continue
opponent.load_state_dict(opponent_policy)
# play n_eval_matches matches against each opponent and record the win rate and action stats
for _ in range(n_eval_matches):
result = rollout(env, learner, opponent, opponent_id, device)
scores[opponent_id].append(result.rewards.sum().item())
# update agent ranks
rating_object.update(policy_ids=[learner_id, opponent_id], scores=[result.rewards.sum().item(),
-result.rewards.sum().item()])
counters = {f"c{i}": 0 for i in range(env.single_action_space.n)}
# get mean logits for each action dim
logits = result.logits.mean(dim=0).cpu().clone().detach().numpy()
# numerically stable softmax over the mean logits
exp_logits = np.exp(logits - logits.max())
probs = exp_logits / exp_logits.sum()
for action, prob in enumerate(probs):
counters[_get_action_name_from_id(action)] = prob
for action_name, frequency in counters.items():
action_stats[action_name].append(frequency)
for i in range(env.single_action_space.n):
action_stats[f'{_get_action_name_from_id(i)}_logit'].append(logits[i])
means = {}
stds = {}
for action_name, frequency in action_stats.items():
means[action_name + '_mean'] = np.mean(frequency)
stds[action_name + '_std'] = np.std(frequency)
for m in means:
action_stats[m].append(means[m])
for s in stds:
action_stats[s].append(stds[s])
mean_scores = np.array([np.mean(np.array(v) > 0) for k, v in scores.items()])
win_rate = np.mean(mean_scores).item()
# import pdb; pdb.set_trace()
learner.train()
return win_rate, action_stats
def update_model(learner, learner_id, optimizer, env, data, n_steps, n_opt_steps, minibatch_size, gamma, gae_lambda,
clip_coef, norm_adv, clip_vloss, max_grad_norm, ent_coef, vf_coef, target_kl, device,
opponent=None, archive=None, ArchiveKL: Optional["ArchiveKL"]=None):
# get data for the model we want to update
rollouts = data[learner_id]
obs = rollouts.obs
actions = rollouts.actions
logprobs = rollouts.logprobs
rewards = rollouts.rewards
dones = rollouts.dones
values = rollouts.values
logits = rollouts.logits
next_obs = data['next_obs']
next_done = data['next_dones']
# print("extracted from storage pre GAE: ", values._version)
print("start of opt: ", learner.critic.critic_out._version)
# bootstrap value if not done
with torch.no_grad():
sort_list = []
advantages = []
returns = []
indices = torch.arange(0, rewards.shape[0]).long().to(device)
next_value = learner.get_value(torch.FloatTensor(next_obs).unsqueeze(0).to(device))
for player in ['player_0', 'player_1']:
mask = np.array(rollouts.agent_id) == player
masked_inds = list(indices[mask])
lastgaelam = 0
for t in reversed(range(mask.sum())):
if t == mask.sum() - 1:
nextnonterminal = 1.0 - next_done
nextvalues = next_value
else:
nextnonterminal = 1.0 - dones[mask][t + 1]
nextvalues = values[mask][t + 1]
delta = rewards[mask][t] + gamma * nextvalues * nextnonterminal - values[mask][t]
lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
advantages.append(lastgaelam)
sort_list.append(masked_inds[t])
returns.append(lastgaelam + values[mask][t])
# scatter the per-player advantages/returns back into the original batch order
order = torch.argsort(torch.LongTensor(sort_list))
advantages = torch.cat(advantages)[order].to(device)
returns = torch.cat(returns)[order].to(device)
# print("extracted from storage post GAE: ", values._version)
# print("return version: ", returns._version)
print("after gae: ", learner.critic.critic_out._version)
# flatten the batch
b_obs = obs.reshape((-1,) + env.single_observation_space.shape)
b_logprobs = logprobs.reshape(-1)
b_actions = actions.reshape((-1,) + env.single_action_space.shape)
b_advantages = advantages.reshape(-1)
b_returns = returns.reshape(-1)
b_values = values.reshape(-1)
b_logits = logits.reshape((-1,) + (env.single_action_space.n,))
# print("flattened: ", b_values._version)
# Optimizing the policy and value network
batch_size = b_obs.shape[0]
b_inds = np.arange(batch_size)
clipfracs = []
for epoch in range(n_opt_steps):
np.random.shuffle(b_inds)
for start in range(0, batch_size, minibatch_size):
end = start + minibatch_size
mb_inds = b_inds[start:end]
_, newlogprob, entropy, newvalue, _ = learner.get_action_and_value(b_obs[mb_inds].unsqueeze(1),
b_actions.long()[mb_inds])
print("getting new log probs: ", learner.critic.critic_out._version)
logits_archive = []
if ArchiveKL is not None:
# get logits for observations from each agent in the novelty archive
for policy_id, policy in archive.items():
if policy_id in ["main", "greedy_main"]:
continue
# load policy into random agent
opponent.load_state_dict(policy)
logits_archive.append(opponent.get_action_and_value(b_obs[mb_inds].unsqueeze(1),
b_actions.long()[mb_inds])[4])
logits_archive = torch.stack(logits_archive, dim=0).detach()
logratio = newlogprob - b_logprobs[mb_inds]
ratio = logratio.exp()
with torch.no_grad():
# calculate approx_kl http://joschu.net/blog/kl-approx.html
old_approx_kl = (-logratio).mean()
approx_kl = ((ratio - 1) - logratio).mean()
clipfracs += [((ratio - 1.0).abs() > clip_coef).float().mean().item()]
mb_advantages = b_advantages[mb_inds]
if norm_adv:
mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)
# Policy loss
pg_loss1 = -mb_advantages * ratio
pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - clip_coef, 1 + clip_coef)
pg_loss = torch.max(pg_loss1, pg_loss2).mean()
# Value loss
newvalue = newvalue.view(-1)
if clip_vloss:
v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
v_clipped = b_values[mb_inds] + torch.clamp(
newvalue - b_values[mb_inds],
-clip_coef,
clip_coef,
)
v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
v_loss = 0.5 * v_loss_max.mean()
else:
v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()
# Entropy loss
entropy_loss = entropy.mean()
# Novelty loss
# try to maximize the KL divergence between the current policy and the archive
novelty_loss = 0
if ArchiveKL is not None:
novelty_loss = ArchiveKL.forward(b_logits[mb_inds].unsqueeze(-1), logits_archive).mean()
# Total loss
loss = pg_loss - ent_coef * entropy_loss + v_loss * vf_coef + novelty_loss * 0.01
optimizer.zero_grad()
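# note: b_logprobs and b_values were produced during the rollout with autograd enabled, so
# this backward also traverses the rollout-time graph that every row of the batch shares via
# torch.cat; a second minibatch or epoch then re-enters that already-freed graph, which is
# likely why retain_graph=True was needed here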
print("pre backward: ", learner.critic.critic_out._version)
loss.backward(retain_graph=True)
print("post backward: ", learner.critic.critic_out._version)
nn.utils.clip_grad_norm_(learner.parameters(), max_grad_norm)
print("post clip pre step: ", learner.critic.critic_out._version)
optimizer.step()
print("post step: ", learner.critic.critic_out._version)
if target_kl is not None:
if approx_kl > target_kl:
break
# print("b_returns final version: ", b_returns._version)
y_pred, y_true = b_values.cpu().detach().numpy(), b_returns.cpu().detach().numpy()
var_y = np.var(y_true)
explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
return (loss.item(), entropy_loss.item(), pg_loss.item(), v_loss.item(),
explained_var, approx_kl.item(), np.mean(clipfracs), old_approx_kl.item()
)
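# ---- rating.py (file boundary inferred from the `from rating import OpenSkillRating` line above) ----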
import numpy as np
from collections import defaultdict
import openskill
def rank(policy_ids, scores):
'''Compute policy rankings from per-agent scores'''
agents = defaultdict(list)
for policy_id, score in zip(policy_ids, scores):
agents[policy_id].append(score)
# Double argsort returns ranks
return np.argsort(np.argsort(
[-np.mean(vals) + 1e-8 * np.random.normal() for policy, vals in
sorted(agents.items())])).tolist()
class OpenSkillRating:
'''OpenSkill Rating wrapper for estimating relative policy skill
Provides a simple method for updating skill estimates from raw
per-agent scores as are typically returned by the environment.'''
def __init__(self, mu, anchor_mu, sigma, agents=[], anchor=None):
'''
Args:
agents: List of agent classes to rank
anchor: Baseline policy name to anchor to mu
mu: Anchor point for the baseline policy (cannot be exactly 0)
sigma: 68/95/99.7 win rate against 1/2/3 sigma lower SR'''
if __debug__:
err = 'Agents must be ordered (e.g. list, not set)'
assert type(agents) != set, err
self.ratings = {}
self.mu = mu
self.anchor_mu = anchor_mu
self.sigma = sigma
for e in agents:
self.add_policy(e)
self.anchor = anchor
self._anchor_baseline()
def __str__(self):
return ', '.join(f'{p}: {int(r.mu)}' for p, r in self.ratings.items())
@property
def stats(self):
return {p: int(r.mu) for p, r in self.ratings.items()}
def _anchor_baseline(self):
'''Resets the anchor point policy to mu SR'''
for agent, rating in self.ratings.items():
rating.sigma = self.sigma
if agent == self.anchor:
rating.mu = self.anchor_mu
rating.sigma = self.sigma
def set_anchor(self, name):
'''TODO: multiple anchors'''
if self.anchor is not None:
self.remove_policy(self.anchor)
if name not in self.ratings:
self.add_policy(name)
self.anchor = name
self._anchor_baseline()
def add_policy(self, name, mu=None, sigma=None):
assert name not in self.ratings, f'Policy {name} already added to ratings'
if mu is not None and sigma is not None:
self.ratings[name] = openskill.Rating(mu=mu, sigma=sigma)
else:
self.ratings[name] = openskill.Rating(mu=self.mu, sigma=self.sigma)
def remove_policy(self, name):
assert name in self.ratings, f'Policy {name} not in ratings'
del self.ratings[name]
def update(self, policy_ids, ranks=None, scores=None):
'''Updates internal skill rating estimates for each policy
You should call this function once per simulated environment
Provide either ranks OR policy_ids and scores
Args:
ranks: List of ranks in the same order as agents
policy_ids: List of policy IDs for each agent episode
scores: List of scores for each agent episode
Returns:
Dictionary of ratings keyed by agent names'''
for pol_id in policy_ids:
assert pol_id in self.ratings, f'policy_id {pol_id} not in ratings'
if __debug__:
err = 'Specify either ranks or scores'
assert (ranks is None) != (scores is None), err
assert self.anchor is not None, 'Set the anchor policy before updating ratings'
if ranks is None:
ranks = rank(policy_ids, scores)
teams = [[self.ratings[e]] for e in policy_ids]
ratings = openskill.rate(teams, rank=ranks)
#ratings = [openskill.create_rating(team[0]) for team in ratings]
for agent, rating in zip(policy_ids, ratings):
self.ratings[agent] = rating[0]
self._anchor_baseline()
#
# return self.ratings
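# ---- c4/utils.py (file boundary inferred from the `from .utils import ...` line in ppo.py above) ----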
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical
import itertools
import numpy as np
from uuid import uuid4
from dataclasses import dataclass
@dataclass
class RolloutData:
episode: list
agent_id: list
obs: torch.Tensor
actions: torch.Tensor
rewards: torch.Tensor
dones: torch.Tensor
next_obs: torch.Tensor
logprobs: torch.Tensor
values: torch.Tensor
logits: torch.Tensor
def to(self, device):
self.obs = self.obs.to(device)
self.actions = self.actions.to(device)
self.rewards = self.rewards.to(device)
self.dones = self.dones.to(device)
self.next_obs = self.next_obs.to(device)
self.logprobs = self.logprobs.to(device)
self.values = self.values.to(device)
self.logits = self.logits.to(device)
return self
def to_dict(self):
return {
"obs": self.obs,
"actions": self.actions,
"rewards": self.rewards,
"dones": self.dones,
"next_obs": self.next_obs,
"logprobs": self.logprobs,
"values": self.values,
'episode': self.episode,
'agent_id': self.agent_id,
'logits': self.logits
}
def combine(self, other):
self.obs = torch.cat((self.obs, other.obs))
self.actions = torch.cat((self.actions, other.actions))
self.rewards = torch.cat((self.rewards, other.rewards))
self.dones = torch.cat((self.dones, other.dones))
self.next_obs = torch.cat((self.next_obs, other.next_obs))
self.logprobs = torch.cat((self.logprobs, other.logprobs))
self.values = torch.cat((self.values, other.values))
self.logits = torch.cat((self.logits, other.logits))
self.episode = self.episode + other.episode
self.agent_id = self.agent_id + other.agent_id
def split_by_episode(self):
episode_ids = list(set(self.episode))
episode_rollouts = {}
for episode_id in episode_ids:
mask = np.array(self.episode) == episode_id
episode_rollouts[episode_id] = RolloutData(
episode=list(itertools.compress(self.episode, mask)),
agent_id=list(itertools.compress(self.agent_id, mask)),
obs=self.obs[mask],
actions=self.actions[mask],
rewards=self.rewards[mask],
dones=self.dones[mask],
next_obs=self.next_obs[mask],
logprobs=self.logprobs[mask],
values=self.values[mask],
logits=self.logits[mask],
)
return episode_rollouts
def append_obs(self, obs):
self.obs = torch.cat((self.obs, obs))
def append_actions(self, actions):
self.actions = torch.cat((self.actions, actions))
def append_rewards(self, rewards):
self.rewards = torch.cat((self.rewards, rewards))
def append_dones(self, dones):
self.dones = torch.cat((self.dones, dones))
def append_next_obs(self, next_obs):
self.next_obs = torch.cat((self.next_obs, next_obs))
def append_logprobs(self, logprobs):
self.logprobs = torch.cat((self.logprobs, logprobs))
def append_values(self, values):
self.values = torch.cat((self.values, values))
def append_logits(self, logits):
self.logits = torch.cat((self.logits, logits))
def append_episode(self, episode):
self.episode = self.episode + episode
def append_agent_id(self, agent_id):
self.agent_id = self.agent_id + agent_id
def rollout(env, main_policy, opponent_policy, opponent_name, device):
# todo rewrite this using the open spiel rl environment API
# it will solve all these issues I've been having with the rollout
main_policy.to(device)
# opponent_policy.to(device)
bit_size = 64
episode_id = uuid4().int >> bit_size
data = RolloutData([],
[],
torch.Tensor([]),
torch.Tensor([]),
torch.Tensor([]),
torch.Tensor([]),
torch.Tensor([]),
torch.Tensor([]),
torch.Tensor([]),
torch.Tensor([]))
env.reset()
for agent_name in env.agent_iter():
observation, reward, termination, truncation, info = env.last()
observation = torch.from_numpy(observation).unsqueeze(0).float()
if termination or truncation:
action = None
else:
action, logprob, entropy, value, logits = main_policy.get_action_and_value(observation)
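# note: logprob, value, and logits are computed with autograd enabled here and stored below
# without detaching, so the returned RolloutData keeps references to the policy's computation graph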
# print("original value object in rollout: ", value._version)
env.step(action)
if termination or truncation:
action = None
else:
next_observation, reward, termination, truncation, info = env.last()
data.append_actions(action)
data.append_rewards(torch.Tensor([reward]))
data.append_values(value)
data.append_obs(observation)
data.append_dones(torch.Tensor([termination]))
data.append_next_obs(torch.from_numpy(next_observation).unsqueeze(0).float())
data.append_logprobs(logprob)
data.append_logits(logits)
data.append_episode([episode_id])
data.append_agent_id([agent_name])
# print("at the end of the rollout: ", data.values._version)
return data
def _get_action_name_from_id(action_id):
# map a Connect Four column index (0-6) to its action name, e.g. 3 -> "c3"
return f"c{action_id}" if 0 <= action_id <= 6 else None
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
torch.nn.init.orthogonal_(layer.weight, std)
torch.nn.init.constant_(layer.bias, bias_const)
return layer
class Agent(nn.Module):
def __init__(self, envs):
super().__init__()
self.critic = nn.Sequential(
layer_init(nn.Linear(envs.observation_space.shape[0], 64)),
nn.ReLU(),
layer_init(nn.Linear(64, 64)),
nn.ReLU(),
layer_init(nn.Linear(64, 1), std=1.0),
)
self.actor = nn.Sequential(
layer_init(nn.Linear(envs.observation_space.shape[0], 64)),
nn.ReLU(),
layer_init(nn.Linear(64, 64)),
nn.ReLU(),
layer_init(nn.Linear(64, envs.single_action_space.n), std=0.01),
)
def get_value(self, x):
return self.critic(x)
def get_action_and_value(self, x, action=None):
if x.ndim == 1:
x = x.unsqueeze(0)
logits = self.actor(x)
# todo -- determine what to do if nans appear in logits
# ensure logits are not nan
# logits = torch.where(torch.isnan(logits), torch.zeros_like(logits), logits)
probs = Categorical(logits=logits)
if action is None:
action = probs.sample()
return action, probs.log_prob(action), probs.entropy(), self.critic(x), logits