PPO implementation in PyTorch
from typing import NamedTuple

import gymnasium as gym
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

n_envs = 8
total_timesteps = 1_000_000
# PPO parameters
learning_rate = 3e-4
n_rollout_steps = 1024
batch_size = 64
n_epochs = 10
gamma = 0.99
gae_lambda = 0.95
clip_range = 0.2
normalize_advantage = True
ent_coef = 0.0
vf_coef = 0.5
max_grad_norm = 0.5

buffer_size = n_envs * n_rollout_steps
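# Each PPO iteration collects n_envs * n_rollout_steps = 8 * 1024 = 8192 transitions
# (buffer_size), then performs n_epochs passes over them in minibatches of batch_size.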
vec_env = [gym.make("BreakoutNoFrameskip-v4", render_mode="rgb_array") for _ in range(n_envs)]

transform = transforms.Compose([transforms.ToTensor(), transforms.Resize((84, 84), antialias=None), transforms.Grayscale()])

last_obs = torch.empty((n_envs, 4, 84, 84), dtype=torch.float32)
for i, env in enumerate(vec_env):
    obs, _ = env.reset()
    last_obs[i, :] = transform(obs)

episode_frame_numbers = []
episode_rewards = []
vec_env_reward = [0 for _ in range(n_envs)]
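# DQN-style Atari preprocessing: each frame is converted to a grayscale 84x84 tensor in [0, 1]
# and four consecutive frames are stacked along the channel axis as the network input.
# vec_env_reward accumulates the per-environment episode return for logging.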
def on_rollout_start():
    episode_frame_numbers.clear()
    episode_rewards.clear()

def step(vec_env, actions):
    buf_obs = torch.empty((n_envs, 4, 84, 84), dtype=torch.float32)
    buf_rews = torch.zeros((n_envs,), dtype=torch.float32)
    buf_done = torch.zeros((n_envs,), dtype=torch.bool)
    for i in range(n_envs):
        for j in range(4):
            obs, rew, terminated, truncated, info = vec_env[i].step(actions[i])
            buf_rews[i] += rew
            vec_env_reward[i] += rew
            if terminated or truncated:
                buf_done[i] = True
                obs, _ = vec_env[i].reset()
                buf_obs[i, :] = transform(obs)
                episode_frame_numbers.append(info["episode_frame_number"])
                episode_rewards.append(vec_env_reward[i])
                vec_env_reward[i] = 0
                break
            buf_obs[i, j] = transform(obs)
    return buf_obs, buf_rews, buf_done
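# step() advances every environment with a frame skip of 4: the same action is repeated
# for 4 frames, those 4 observations form the next frame stack, and their rewards are summed.
# When an episode ends, the environment is reset, the reset frame fills the whole stack,
# and the episode length/return are recorded for the rollout statistics.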
class RolloutBufferSamples(NamedTuple):
    observations: torch.Tensor
    actions: torch.Tensor
    old_values: torch.Tensor
    old_log_prob: torch.Tensor
    advantages: torch.Tensor
    returns: torch.Tensor

class RolloutBuffer:
    def __init__(self, buffer_size, n_envs, obs_shape, action_dim, device):
        self.buffer_size = buffer_size
        self.n_envs = n_envs
        self.obs_shape = obs_shape
        self.action_dim = action_dim
        self.device = device

    def reset(self):
        self.observations = torch.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=torch.float32)
        self.actions = torch.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=torch.int64)
        self.rewards = torch.zeros((self.buffer_size, self.n_envs), dtype=torch.float32)
        self.returns = torch.zeros((self.buffer_size, self.n_envs), dtype=torch.float32)
        self.episode_starts = torch.zeros((self.buffer_size, self.n_envs), dtype=torch.float32)
        self.values = torch.zeros((self.buffer_size, self.n_envs), dtype=torch.float32)
        self.log_probs = torch.zeros((self.buffer_size, self.n_envs), dtype=torch.float32)
        self.advantages = torch.zeros((self.buffer_size, self.n_envs), dtype=torch.float32)
        self.pos = 0
        self.generator_ready = False

    def add(self, obs, action, reward, episode_start, value, log_prob):
        self.observations[self.pos] = obs
        self.actions[self.pos] = action
        self.rewards[self.pos] = reward
        self.episode_starts[self.pos] = episode_start
        self.values[self.pos] = value.cpu().flatten()
        self.log_probs[self.pos] = log_prob.cpu()
        self.pos += 1
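    # compute_returns_and_advantage implements GAE(lambda):
    #   delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_{t+1}) - V(s_t)
    #   A_t     = delta_t + gamma * lambda * (1 - done_{t+1}) * A_{t+1}
    # and the value targets are the lambda-returns: returns = advantages + values.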
    def compute_returns_and_advantage(self, last_values, dones):
        last_values = last_values.cpu().flatten()
        last_gae_lam = 0
        for step in reversed(range(self.buffer_size)):
            if step == self.buffer_size - 1:
                next_non_terminal = 1.0 - dones.to(torch.float32)
                next_values = last_values
            else:
                next_non_terminal = 1.0 - self.episode_starts[step + 1]
                next_values = self.values[step + 1]
            delta = self.rewards[step] + gamma * next_values * next_non_terminal - self.values[step]
            last_gae_lam = delta + gamma * gae_lambda * next_non_terminal * last_gae_lam
            self.advantages[step] = last_gae_lam
        self.returns = self.advantages + self.values
    @staticmethod
    def swap_and_flatten(arr):
        shape = arr.shape
        if len(shape) < 3:
            shape = (*shape, 1)
        return arr.swapaxes(0, 1).reshape(shape[0] * shape[1], *shape[2:])

    def get(self, batch_size):
        indices = np.random.permutation(self.buffer_size * self.n_envs)
        if not self.generator_ready:
            self.observations = self.swap_and_flatten(self.observations)
            self.actions = self.swap_and_flatten(self.actions)
            self.values = self.swap_and_flatten(self.values)
            self.log_probs = self.swap_and_flatten(self.log_probs)
            self.advantages = self.swap_and_flatten(self.advantages)
            self.returns = self.swap_and_flatten(self.returns)
            self.generator_ready = True
        start_idx = 0
        while start_idx < self.buffer_size * self.n_envs:
            yield self._get_samples(indices[start_idx : start_idx + batch_size])
            start_idx += batch_size

    def to_torch(self, array):
        return torch.as_tensor(array, device=self.device)

    def _get_samples(self, batch_inds):
        data = (
            self.observations[batch_inds],
            self.actions[batch_inds],
            self.values[batch_inds].flatten(),
            self.log_probs[batch_inds].flatten(),
            self.advantages[batch_inds].flatten(),
            self.returns[batch_inds].flatten(),
        )
        return RolloutBufferSamples(*tuple(map(self.to_torch, data)))
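# Nature-DQN style network: three convolutions (4->32, k8 s4; 32->64, k4 s2; 64->64, k3 s1)
# followed by a 3136->512 fully connected layer (64 * 7 * 7 = 3136 for 84x84 inputs),
# with a policy head over Breakout's 4 actions and a scalar value head sharing the features.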
class PolicyValueNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(4, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        self.fc = nn.Linear(3136, 512)
        self.fc_p = nn.Linear(512, 4)  # policy head (logits over 4 actions)
        self.fc_v = nn.Linear(512, 1)  # value head

    def extract_features(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.fc(x.flatten(1)))
        return x

    def forward(self, x):
        x = self.extract_features(x)
        # output heads stay linear: logits and the state value must not be clamped by ReLU
        policy = self.fc_p(x)
        value = self.fc_v(x)
        return policy, value

    def predict_values(self, x):
        x = self.extract_features(x)
        value = self.fc_v(x)
        return value
    @staticmethod
    def log_prob(value, logits):
        value, log_pmf = torch.broadcast_tensors(value, logits)
        value = value[..., :1]
        log_prob = log_pmf.gather(-1, value).squeeze(-1)
        return log_prob

    @staticmethod
    def entropy(logits):
        min_real = torch.finfo(logits.dtype).min
        logits = torch.clamp(logits, min=min_real)
        probs = F.softmax(logits, dim=-1)
        p_log_p = logits * probs
        return -p_log_p.sum(-1)

    def sample(self, obs):
        logits, values = self.forward(obs)
        # Normalize
        logits = logits - logits.logsumexp(dim=-1, keepdim=True)
        probs = F.softmax(logits, dim=-1)
        actions = torch.multinomial(probs, 1, True)
        return actions, values, self.log_prob(actions, logits)

    def evaluate_actions(self, obs, actions):
        logits, values = self.forward(obs)
        # Normalize
        logits = logits - logits.logsumexp(dim=-1, keepdim=True)
        log_prob = self.log_prob(actions, logits)
        entropy = self.entropy(logits)
        return values, log_prob, entropy
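# sample() / evaluate_actions() implement a categorical policy by hand: logits are
# normalized with logsumexp (log-softmax), actions are drawn with torch.multinomial,
# log pi(a|s) is gathered for the PPO ratio, and entropy is -sum(p * log p).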
model = PolicyValueNet()
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, eps=1e-5)

rollout_buffer = RolloutBuffer(n_rollout_steps, n_envs, (4, 84, 84), 1, device)

iteration = 0
num_timesteps = 0
global_step = 0
while num_timesteps < total_timesteps:
    last_episode_starts = torch.ones((n_envs,), dtype=torch.bool)

    # collect_rollouts
    model.eval()
    n_steps = 0
    rollout_buffer.reset()
    on_rollout_start()
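    # Collect n_rollout_steps transitions per environment with the current (frozen) policy;
    # the update phase below then reuses this fixed batch for n_epochs of minibatch SGD.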
    while n_steps < n_rollout_steps:
        with torch.no_grad():
            obs_tensor = last_obs.to(device)
            actions, values, log_probs = model.sample(obs_tensor)
        actions = actions.cpu()
        new_obs, rewards, dones = step(vec_env, actions.reshape(-1))
        num_timesteps += n_envs
        n_steps += 1
        rollout_buffer.add(
            last_obs,
            actions,
            rewards,
            last_episode_starts,
            values,
            log_probs,
        )
        last_obs = new_obs
        last_episode_starts = dones
    with torch.no_grad():
        # Compute value for the last timestep
        values = model.predict_values(new_obs.to(device))

    # Compute the lambda-return (TD(lambda) estimate) and GAE(lambda) advantage
    rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)

    iteration += 1

    # Logs (skip if no episode finished during this rollout)
    if episode_rewards:
        writer.add_scalar("rollout/ep_len_mean", sum(episode_frame_numbers) / len(episode_frame_numbers), global_step)
        writer.add_scalar("rollout/ep_rew_mean", sum(episode_rewards) / len(episode_rewards), global_step)
    # train
    model.train()
    for epoch in range(n_epochs):
        for rollout_data in rollout_buffer.get(batch_size):
            actions = rollout_data.actions
            values, log_prob, entropy = model.evaluate_actions(rollout_data.observations, actions)
            values = values.flatten()

            # Normalize advantage
            advantages = rollout_data.advantages
            if normalize_advantage:
                advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

            # ratio between old and new policy, should be one at the first iteration
            ratio = torch.exp(log_prob - rollout_data.old_log_prob)

            # clipped surrogate loss
            policy_loss_1 = advantages * ratio
            policy_loss_2 = advantages * torch.clamp(ratio, 1 - clip_range, 1 + clip_range)
            policy_loss = -torch.min(policy_loss_1, policy_loss_2).mean()

            # Value loss using the TD(gae_lambda) target
            value_loss = F.mse_loss(rollout_data.returns, values)

            # Entropy loss favors exploration
            entropy_loss = -torch.mean(entropy)

            loss = policy_loss + ent_coef * entropy_loss + vf_coef * value_loss

            # Optimization step
            optimizer.zero_grad()
            loss.backward()
            # Clip grad norm
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()

    # Logs (values from the last minibatch of this iteration)
    writer.add_scalar("train/policy_loss", policy_loss.item(), global_step)
    writer.add_scalar("train/value_loss", value_loss.item(), global_step)
    writer.add_scalar("train/entropy_loss", entropy_loss.item(), global_step)
    writer.add_scalar("train/loss", loss.item(), global_step)

    global_step += 1
checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
}
torch.save(checkpoint, "checkpoint.pt")
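# To resume from the saved checkpoint, the state dicts can be restored before training,
# e.g. (a minimal sketch, not part of the original script):
#   checkpoint = torch.load("checkpoint.pt", map_location=device)
#   model.load_state_dict(checkpoint["model"])
#   optimizer.load_state_dict(checkpoint["optimizer"])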