Skip to content

Instantly share code, notes, and snippets.

@TadaoYamaoka
Created October 8, 2023 12:22
Show Gist options
  • Save TadaoYamaoka/78aec6d73dfc8a5c60f330385f73a569 to your computer and use it in GitHub Desktop.
Save TadaoYamaoka/78aec6d73dfc8a5c60f330385f73a569 to your computer and use it in GitHub Desktop.
PyTorchによるPPO実装
from typing import NamedTuple
import gymnasium as gym
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()
# Use the GPU when one is available, but fall back to CPU instead of crashing
# on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_envs = 8
total_timesteps = 1_000_000
# PPO Parameter
learning_rate = 3e-4
n_rollout_steps = 1024
batch_size = 64
n_epochs = 10
gamma = 0.99
gae_lambda = 0.95
clip_range = 0.2
normalize_advantage = True
ent_coef = 0.0
vf_coef = 0.5
max_grad_norm = 0.5
buffer_size = n_envs * n_rollout_steps
vec_env = [gym.make("BreakoutNoFrameskip-v4", render_mode="rgb_array") for _ in range(n_envs)]
# Screen -> float tensor -> 84x84 -> grayscale (Grayscale defaults to a single
# output channel, so each call yields one 84x84 frame).
transform = transforms.Compose([transforms.ToTensor(), transforms.Resize((84, 84), antialias=None), transforms.Grayscale()])
# Current observation of every env: 4 stacked 84x84 frames.
last_obs = torch.empty((n_envs, 4, 84, 84), dtype=torch.float32)
for i, env in enumerate(vec_env):
    obs, _ = env.reset()
    # A single frame broadcast into all 4 slots of the stack.
    last_obs[i, :] = transform(obs)
# Per-rollout statistics of finished episodes (cleared by on_rollout_start).
episode_frame_numbers = []
episode_rewards = []
# Running reward of the episode currently in progress in each env.
vec_env_reward = [0] * n_envs
def on_rollout_start():
    """Forget the finished-episode statistics gathered during the previous rollout.

    The module-level lists are emptied in place, so every other reference to
    them sees the cleared contents.
    """
    for stats in (episode_frame_numbers, episode_rewards):
        del stats[:]
def step(vec_env, actions):
    """Advance each of the first ``n_envs`` environments by one agent step.

    Every env repeats ``actions[i]`` for up to 4 emulator frames. The processed
    frame after the j-th repeat is stored in slot j of that env's observation,
    and the rewards of all repeats are summed.

    If the episode terminates or is truncated mid-step:
    - the env is reset and all 4 slots are filled with the reset frame;
    - ``info["episode_frame_number"]`` from the final step and the episode's
      accumulated reward are appended to the module-level
      ``episode_frame_numbers`` / ``episode_rewards``;
    - the remaining repeats are skipped.

    Relies on the module-level ``n_envs``, ``transform`` and ``vec_env_reward``.

    Returns:
        ``(observations, rewards, dones)`` with shapes ``(n_envs, 4, 84, 84)``
        float32, ``(n_envs,)`` float32 and ``(n_envs,)`` bool.
    """
    observations = torch.empty((n_envs, 4, 84, 84), dtype=torch.float32)
    rewards = torch.zeros((n_envs,), dtype=torch.float32)
    dones = torch.zeros((n_envs,), dtype=torch.bool)
    for env_idx in range(n_envs):
        env = vec_env[env_idx]
        action = actions[env_idx]
        for frame_idx in range(4):
            obs, rew, terminated, truncated, info = env.step(action)
            rewards[env_idx] += rew
            vec_env_reward[env_idx] += rew
            if not (terminated or truncated):
                observations[env_idx, frame_idx] = transform(obs)
                continue
            # Episode over: reset, fill the whole stack with the first frame of
            # the new episode, record the finished episode's stats.
            dones[env_idx] = True
            obs, _ = env.reset()
            observations[env_idx, :] = transform(obs)
            episode_frame_numbers.append(info["episode_frame_number"])
            episode_rewards.append(vec_env_reward[env_idx])
            vec_env_reward[env_idx] = 0
            break
    return observations, rewards, dones
# One minibatch drawn from RolloutBuffer.get():
#   observations  - stacked frames fed to the policy
#   actions       - actions taken during collection
#   old_values    - value estimates recorded during collection
#   old_log_prob  - log-probabilities of the actions under the collecting policy
#   advantages    - GAE(lambda) advantages
#   returns       - advantages + old_values (value-function targets)
RolloutBufferSamples = NamedTuple(
    "RolloutBufferSamples",
    [
        ("observations", torch.Tensor),
        ("actions", torch.Tensor),
        ("old_values", torch.Tensor),
        ("old_log_prob", torch.Tensor),
        ("advantages", torch.Tensor),
        ("returns", torch.Tensor),
    ],
)
class RolloutBuffer:
    """On-policy storage for one PPO rollout, with GAE and minibatch sampling.

    Layout and lifecycle:
    - While collecting, every tensor is laid out as ``(buffer_size, n_envs, ...)``.
    - The first call to :meth:`get` flattens the training fields to
      ``(buffer_size * n_envs, ...)`` in env-major order.
    - Later calls reuse that layout until :meth:`reset` reallocates the storage.
    - Storage is only allocated by :meth:`reset`, so call it before :meth:`add`.
    """

    def __init__(self, buffer_size, n_envs, obs_shape, action_dim, device):
        # buffer_size: number of rollout steps stored per env.
        self.buffer_size, self.n_envs = buffer_size, n_envs
        self.obs_shape, self.action_dim = obs_shape, action_dim
        # Device that sampled minibatches are moved to.
        self.device = device

    def reset(self):
        """Allocate fresh zeroed (CPU) storage and rewind the write position."""
        steps_by_envs = (self.buffer_size, self.n_envs)
        self.observations = torch.zeros((*steps_by_envs, *self.obs_shape), dtype=torch.float32)
        self.actions = torch.zeros((*steps_by_envs, self.action_dim), dtype=torch.int64)
        self.rewards = torch.zeros(steps_by_envs, dtype=torch.float32)
        self.returns = torch.zeros(steps_by_envs, dtype=torch.float32)
        self.episode_starts = torch.zeros(steps_by_envs, dtype=torch.float32)
        self.values = torch.zeros(steps_by_envs, dtype=torch.float32)
        self.log_probs = torch.zeros(steps_by_envs, dtype=torch.float32)
        self.advantages = torch.zeros(steps_by_envs, dtype=torch.float32)
        self.pos = 0
        self.generator_ready = False

    def add(self, obs, action, reward, episode_start, value, log_prob):
        """Store one transition for every env at the current row, then advance.

        ``value`` and ``log_prob`` are copied to CPU; ``value`` is flattened first.
        """
        row = self.pos
        self.observations[row] = obs
        self.actions[row] = action
        self.rewards[row] = reward
        self.episode_starts[row] = episode_start
        self.values[row] = value.cpu().flatten()
        self.log_probs[row] = log_prob.cpu()
        self.pos = row + 1

    def compute_returns_and_advantage(self, last_values, dones):
        """Fill ``advantages`` with GAE(lambda) and set ``returns = advantages + values``.

        For each step t, walking backwards:
            delta_t = r_t + gamma * V_{t+1} * nonterm_{t+1} - V_t
            A_t     = delta_t + gamma * gae_lambda * nonterm_{t+1} * A_{t+1}

        Bootstrapping:
        - Inside the buffer, ``V_{t+1}`` and ``nonterm_{t+1}`` come from the
          next row (``1 - episode_starts``).
        - After the last row they come from ``last_values`` and ``1 - dones``.

        Uses the module-level ``gamma`` and ``gae_lambda``.
        """
        last_values = last_values.cpu().flatten()
        running = 0
        for t in range(self.buffer_size - 1, -1, -1):
            if t + 1 < self.buffer_size:
                not_done = 1.0 - self.episode_starts[t + 1]
                bootstrap = self.values[t + 1]
            else:
                not_done = 1.0 - dones.to(torch.float32)
                bootstrap = last_values
            delta = self.rewards[t] + gamma * bootstrap * not_done - self.values[t]
            running = delta + gamma * gae_lambda * not_done * running
            self.advantages[t] = running
        self.returns = self.advantages + self.values

    @staticmethod
    def swap_and_flatten(arr):
        """Turn ``(steps, envs, *rest)`` into ``(envs * steps, *rest)``, env-major.

        2-D inputs get a trailing axis of size 1 so the result is 2-D.
        """
        dims = tuple(arr.shape)
        if len(dims) < 3:
            dims += (1,)
        n_steps, n_env, *rest = dims
        return arr.swapaxes(0, 1).reshape(n_steps * n_env, *rest)

    def get(self, batch_size):
        """Yield shuffled :class:`RolloutBufferSamples` minibatches covering the buffer once.

        The permutation uses NumPy's global RNG. The last minibatch is smaller
        when ``batch_size`` does not divide ``buffer_size * n_envs``.
        """
        n_samples = self.buffer_size * self.n_envs
        order = np.random.permutation(n_samples)
        if not self.generator_ready:
            for attr in ("observations", "actions", "values", "log_probs", "advantages", "returns"):
                setattr(self, attr, self.swap_and_flatten(getattr(self, attr)))
            self.generator_ready = True
        start = 0
        while start < n_samples:
            yield self._get_samples(order[start : start + batch_size])
            start += batch_size

    def to_torch(self, array):
        """Return ``array`` as a tensor on the buffer's device."""
        return torch.as_tensor(array, device=self.device)

    def _get_samples(
        self,
        batch_inds
    ):
        """Gather the rows ``batch_inds`` of the flattened storage onto the device."""
        per_sample = tuple(
            arr[batch_inds].flatten()
            for arr in (self.values, self.log_probs, self.advantages, self.returns)
        )
        picked = (self.observations[batch_inds], self.actions[batch_inds], *per_sample)
        return RolloutBufferSamples(*[self.to_torch(t) for t in picked])
class PolicyValueNet(nn.Module):
    """Actor-critic CNN with a shared convolutional trunk.

    ``forward`` maps a batch of 4x84x84 observations to unnormalised action
    logits of shape ``(batch, n_actions)`` and state values of shape ``(batch, 1)``.
    """

    def __init__(self, n_actions=4):
        super().__init__()
        self.conv1 = nn.Conv2d(4, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        # 84x84 input -> 20x20 -> 9x9 -> 7x7, so the flattened size is 64 * 7 * 7 = 3136.
        self.fc = nn.Linear(3136, 512)
        self.fc_p = nn.Linear(512, n_actions)
        self.fc_v = nn.Linear(512, 1)

    def extract_features(self, x):
        """Return the 512-dim shared features for ``x`` of shape ``(batch, 4, 84, 84)``."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.fc(x.flatten(1)))
        return x

    def forward(self, x):
        """Return ``(logits, value)``.

        Both heads are linear. A ReLU on them would clamp logits and values to
        be non-negative and zero their gradients whenever the pre-activation is
        negative, which can permanently freeze an action's logit or the value head.
        """
        x = self.extract_features(x)
        policy = self.fc_p(x)
        value = self.fc_v(x)
        return policy, value

    def predict_values(self, x):
        """Return state values of shape ``(batch, 1)`` (same head as ``forward``)."""
        return self.fc_v(self.extract_features(x))

    @staticmethod
    def log_prob(value, logits):
        """Log-probability of the chosen actions under normalised log-probs ``logits``.

        ``value`` holds int64 action indices with a trailing axis of size 1
        (e.g. ``(batch, 1)`` as returned by :meth:`sample`). The result has
        that trailing axis removed.
        """
        value, log_pmf = torch.broadcast_tensors(value, logits)
        value = value[..., :1]
        log_prob = log_pmf.gather(-1, value).squeeze(-1)
        return log_prob

    @staticmethod
    def entropy(logits):
        """Entropy of the categorical distribution given normalised log-probs ``logits``."""
        # Clamp -inf to the most negative finite value so 0 * log(0) gives 0, not NaN.
        min_real = torch.finfo(logits.dtype).min
        logits = torch.clamp(logits, min=min_real)
        probs = F.softmax(logits, dim=-1)
        p_log_p = logits * probs
        return -p_log_p.sum(-1)

    def sample(self, obs):
        """Sample one action per observation.

        Returns ``(actions, values, log_probs)`` with ``actions`` of shape ``(batch, 1)``.
        """
        logits, values = self.forward(obs)
        # Normalize to log-probabilities
        logits = logits - logits.logsumexp(dim=-1, keepdim=True)
        probs = F.softmax(logits, dim=-1)
        actions = torch.multinomial(probs, 1, True)
        return actions, values, self.log_prob(actions, logits)

    def evaluate_actions(self, obs, actions):
        """Return ``(values, log_prob, entropy)`` of given ``actions`` under the current policy."""
        logits, values = self.forward(obs)
        # Normalize to log-probabilities
        logits = logits - logits.logsumexp(dim=-1, keepdim=True)
        log_prob = self.log_prob(actions, logits)
        entropy = self.entropy(logits)
        return values, log_prob, entropy
model = PolicyValueNet()
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, eps=1e-5)
rollout_buffer = RolloutBuffer(n_rollout_steps, n_envs, (4, 84, 84), 1, device)
iteration = 0
num_timesteps = 0
global_step = 0
# Every env was just reset, so all of them start a new episode. The flag is
# initialised once and then carried across rollouts, so the value stored at the
# start of each rollout reflects the real episode boundary rather than being
# forced to True every iteration.
last_episode_starts = torch.ones((n_envs,), dtype=torch.bool)
while num_timesteps < total_timesteps:
    # ---- collect_rollouts ----
    model.eval()
    n_steps = 0
    rollout_buffer.reset()
    on_rollout_start()
    while n_steps < n_rollout_steps:
        with torch.no_grad():
            obs_tensor = last_obs.to(device)
            actions, values, log_probs = model.sample(obs_tensor)
        actions = actions.cpu()
        new_obs, rewards, dones = step(vec_env, actions.reshape(-1))
        num_timesteps += n_envs
        n_steps += 1
        rollout_buffer.add(
            last_obs,
            actions,
            rewards,
            last_episode_starts,
            values,
            log_probs,
        )
        last_obs = new_obs
        last_episode_starts = dones
    with torch.no_grad():
        # Compute value for the last timestep (bootstrap for the final GAE step)
        values = model.predict_values(new_obs.to(device))
    # Compute the lambda-return (TD(lambda) estimate) and GAE(lambda) advantage
    rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)
    iteration += 1

    # Logs: if no episode ended during this rollout the lists are empty, so skip
    # the means instead of dividing by zero.
    if episode_frame_numbers:
        writer.add_scalar("rollout/ep_len_mean", sum(episode_frame_numbers) / len(episode_frame_numbers), global_step)
    if episode_rewards:
        writer.add_scalar("rollout/ep_rew_mean", sum(episode_rewards) / len(episode_rewards), global_step)

    # ---- train ----
    model.train()
    for epoch in range(n_epochs):
        for rollout_data in rollout_buffer.get(batch_size):
            actions = rollout_data.actions
            values, log_prob, entropy = model.evaluate_actions(rollout_data.observations, actions)
            values = values.flatten()
            # Normalize advantage; std() of a single-element minibatch is NaN,
            # so only normalise batches with more than one sample.
            advantages = rollout_data.advantages
            if normalize_advantage and len(advantages) > 1:
                advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
            # ratio between old and new policy, should be one at the first iteration
            ratio = torch.exp(log_prob - rollout_data.old_log_prob)
            # clipped surrogate loss
            policy_loss_1 = advantages * ratio
            policy_loss_2 = advantages * torch.clamp(ratio, 1 - clip_range, 1 + clip_range)
            policy_loss = -torch.min(policy_loss_1, policy_loss_2).mean()
            # Value loss using the TD(gae_lambda) target
            value_loss = F.mse_loss(rollout_data.returns, values)
            # Entropy loss favor exploration
            entropy_loss = -torch.mean(entropy)
            loss = policy_loss + ent_coef * entropy_loss + vf_coef * value_loss
            # Optimization step
            optimizer.zero_grad()
            loss.backward()
            # Clip grad norm
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            # Logs
            writer.add_scalar("train/policy_loss", policy_loss.item(), global_step)
            writer.add_scalar("train/value_loss", value_loss.item(), global_step)
            writer.add_scalar("train/entropy_loss", entropy_loss.item(), global_step)
            writer.add_scalar("train/loss", loss.item(), global_step)
            global_step += 1
    checkpoint = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    torch.save(checkpoint, "checkpoint.pt")

# Flush pending TensorBoard events and release the environments.
writer.close()
for env in vec_env:
    env.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment