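"""Advantage actor-critic agent for Atari Pong (PongNoFrameskip-v4).

Stacks the last four RGB frames, feeds them through a small conv net with
separate actor and critic heads, trains on a shuffled per-episode experience
buffer, logs losses to TensorBoard, and periodically checkpoints the model.
"""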
import torch
import gym
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.distributions import Categorical
import torch.nn.functional as F
import random
from tensorboardX import SummaryWriter
import os
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--checkpoint_path', type=str, default=None,
                    help="path of checkpoint pt file")
args = parser.parse_args()
env = gym.make('PongNoFrameskip-v4')
# Matplotlib figure set up for optional plotting (not updated below).
fig = plt.figure()
ax1 = fig.add_subplot(1, 1, 1)
xs = []
ys = []
yt = []
gamma = 0.99
tb_writer = SummaryWriter()
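# tensorboardX's SummaryWriter logs to a ./runs/ subdirectory by default;
# view the training curves with `tensorboard --logdir runs`.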
class ActorCritic(nn.Module):
    def __init__(self):
        super(ActorCritic, self).__init__()
        # Four stacked RGB frames -> 3 * 4 = 12 input channels.
        self.conv1 = torch.nn.Conv2d(3 * 4, 32, 5, stride=3)
        self.conv2 = torch.nn.Conv2d(32, 16, 5, stride=3)
        self.conv3 = torch.nn.Conv2d(16, 8, 3, stride=1)
        self.conv4 = torch.nn.Conv2d(8, 1, 3, stride=1)
        self.hidden = torch.nn.Linear(216, 128)
        self.critic_linear = torch.nn.Linear(128, 1)
        self.actor_linear = torch.nn.Linear(128, env.action_space.n)

    def forward(self, x):
        x = F.leaky_relu(self.conv1(x))
        x = F.leaky_relu(self.conv2(x))
        x = F.leaky_relu(self.conv3(x))
        x = F.leaky_relu(self.conv4(x))
        a = x.flatten(start_dim=1)
        b = self.hidden(a)
        value = self.critic_linear(b)
        policy = self.actor_linear(b)
        # Critic head: scalar state value. Actor head: one logit per action.
        return torch.squeeze(value), policy
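# Shape sanity check (a sketch, assuming Pong's raw 210x160x3 frames): four
# stacked frames give a (1, 12, 210, 160) input, which the conv stack maps
# to (1, 1, 18, 12); flattening yields 18 * 12 = 216 features, matching the
# in_features of self.hidden above. For example:
#   with torch.no_grad():
#       v, p = ActorCritic().cuda()(torch.zeros(1, 12, 210, 160).cuda())
#   # v is a scalar value estimate, p has shape (1, env.action_space.n).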
actor_critic = ActorCritic().cuda()
if args.checkpoint_path is not None:
    print("Resuming from checkpoint: %s" % args.checkpoint_path)
    checkpoint = torch.load(args.checkpoint_path)
    actor_critic.load_state_dict(checkpoint)
criterion = nn.MSELoss()
optimizer = optim.Adam(actor_critic.parameters(), lr=1e-4)
loss_step = 0
experience = []
sum_policy_loss = 0.0
sum_entropy_loss = 0.0
sum_policy_count = 1.0
policy_experience = []
sum_loss = 0.0
sum_value_loss = 0.0
sum_value_count = 0.0
# Anomaly detection helps catch NaNs in backward but slows training.
torch.autograd.set_detect_anomaly(True)
mean = None
std = None
for i_episode in range(500000):
    # Epsilon schedule (computed but not used by the sampling below).
    epsilon = 1. - (i_episode / 30.0)
    observation = env.reset()
    episode_entropy = 0.0
    last_four = []
    for t in range(100000):
        env.render()
        # Get the current image frame as a (1, 3, 210, 160) float tensor.
        s = torch.from_numpy(observation).float()
        s = s.cuda().unsqueeze(0)
        s = s.permute(0, 3, 1, 2)
        # Freeze normalization statistics from the very first frame.
        if mean is None:
            mean = s.mean()
        if std is None:
            std = s.std()
        # Normalize the frame (sphering our image).
        s = (s - mean) / (std + np.finfo(np.float32).eps)
        if len(last_four) == 4:
            last_four.pop(0)
            last_four.append(s)
        else:
            # First frame of the episode: seed the buffer with four copies.
            last_four.append(s)
            last_four.append(s)
            last_four.append(s)
            last_four.append(s)
        # Concat the previous four frames together along the channel dim.
        s = torch.cat(last_four, dim=1)
        # Get a probability distribution over the actions we should take.
        estimated_value, action_logits = actor_critic(s)
        action_logits = action_logits.squeeze(0)
        prob = F.softmax(action_logits, dim=-1)
        # Show the critic's value estimate for the current state.
        print(f"\r{estimated_value.item():1.1} \r", end='')
        dist = Categorical(prob)
        # Sample an action from the distribution.
        maxa_tensor = dist.sample()
        maxa = maxa_tensor.item()
        episode_entropy += dist.entropy().mean()
        log_prob = dist.log_prob(maxa_tensor)
        entropy = dist.entropy().mean()
        # print(f"policy selected {maxa}")
        # Take the action.
        observation, reward, done, info = env.step(maxa)
        # Save the transition in our experience buffer.
        experience.append((s, torch.FloatTensor([reward]).cuda(), log_prob, estimated_value, maxa_tensor))
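        # (The reward slot of each tuple is overwritten with the discounted
        # return once the episode ends; see the learning block below.)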
        if done:
            print(f"learning from {t} frames")
            sum_reward = 0.0
            # Discount all the rewards.
            total_reward = 0.0
            total_discounted_reward = 0.0
            total_advantage = 0.0
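            # Walk the episode backwards so each step's return satisfies
            # G_t = r_t + gamma * G_{t+1} (with G after the final step = 0).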
            for i in reversed(range(len(experience))):
                s, reward, b_prob, estimated_value, maxa_tensor = experience[i]
                total_reward += reward
                sum_reward = gamma * sum_reward + reward
                total_discounted_reward += sum_reward
                experience[i] = (s, sum_reward, b_prob, estimated_value, maxa_tensor)
            # Shuffle the experience to decorrelate consecutive frames.
            random.shuffle(experience)
            for i in experience:
                s, reward, on_policy_log_prob, prev_estimated_value, maxa_tensor = i
                estimated_value, action_logits = actor_critic(s)
                action_logits = action_logits.squeeze(0)
                prob = F.softmax(action_logits, dim=-1)
                dist = Categorical(prob)
                log_prob = dist.log_prob(maxa_tensor)
                # Since the model is changing as we work our way through the
                # experience replay, this is technically an off-policy learning
                # process, so we must do importance sampling. (Note that because
                # the old log-prob is detached, subtracting it only shifts each
                # sample's loss by a constant; the gradient is unchanged.)
                log_prob -= on_policy_log_prob.detach()
                entropy = dist.entropy().mean()
                optimizer.zero_grad()
                advantage = reward.unsqueeze(0) - estimated_value.unsqueeze(0)
                value_loss = criterion(estimated_value.unsqueeze(0), reward)
                # Policy-gradient loss with an entropy bonus to keep exploring.
                policy_loss = (-log_prob * advantage.detach()).mean() - 0.02 * entropy
                # Use .item() so the running sums don't retain autograd graphs.
                sum_policy_loss += (policy_loss.item() + 0.02 * entropy.item())
                sum_entropy_loss -= 0.02 * entropy.item()
                # Joint objective: actor loss plus 0.5-weighted critic loss.
                loss = policy_loss + 0.5 * value_loss
                loss.backward()
                optimizer.step()
                sum_value_loss += value_loss.item()
                sum_loss += loss.item()
                sum_value_count += 1.0
                total_advantage += advantage.item()
                del s
                del estimated_value
                del action_logits
                del log_prob
            torch.cuda.empty_cache()
            # Write some values to tensorboard.
            tb_writer.add_scalars('loss',
                                  {
                                      "loss": sum_loss / sum_value_count,
                                      "value_loss": sum_value_loss / sum_value_count,
                                      "policy_loss": sum_policy_loss / sum_value_count,
                                      "entropy_loss": sum_entropy_loss / sum_value_count,
                                      "total_reward": total_reward,
                                      "total_discounted_reward": total_discounted_reward,
                                      "total_advantage": total_advantage,
                                  },
                                  loss_step)
            sum_policy_loss = 0.0
            sum_entropy_loss = 0.0
            sum_policy_count = 0.0
            sum_value_loss = 0.0
            sum_value_count = 0.0
            sum_loss = 0.0
            if i_episode % 10 == 0:
                # Make sure the checkpoint directory exists before saving.
                os.makedirs("./checkpoints/", exist_ok=True)
                save_path = os.path.join("./checkpoints/", 'chkpt_%d.pt' % i_episode)
                torch.save(actor_critic.state_dict(), save_path)
            loss_step += 1
            experience = []
            break
env.close()
plt.show()
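# Example usage (the script filename here is hypothetical):
#   python pong_actor_critic.py
#   python pong_actor_critic.py --checkpoint_path checkpoints/chkpt_10.pt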