@JohnAllen
Last active February 8, 2019 19:14
import argparse
import gym
import numpy as np
from itertools import count
from collections import namedtuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
# state shape is [65]
class Policy(nn.Module):
    def __init__(self, num_actions):
        super(Policy, self).__init__()
        self.num_cuda_devices = torch.cuda.device_count()
        print("Let's use", self.num_cuda_devices, "GPUs!")
        if self.num_cuda_devices > 1:
            self.input_layer = nn.Linear(65, 128)
            self.hidden_1 = nn.DataParallel(nn.Linear(128, 256))
            self.hidden_2 = nn.DataParallel(nn.Linear(256, 128))
            # initial GRU hidden state: (num_layers=5, batch=1, hidden_size=256);
            # a plain tensor attribute like this is not moved by model.to(device)
            self.hidden_state = torch.zeros(5, 1, 256)
        elif self.num_cuda_devices == 1:
            self.input_layer = nn.Linear(65, 128).cuda()
            self.hidden_1 = nn.Linear(128, 256).cuda()
            self.hidden_2 = nn.Linear(256, 128).cuda()
            self.hidden_state = torch.zeros(5, 1, 256).cuda()
        else:
            self.input_layer = nn.Linear(65, 128)
            self.hidden_1 = nn.Linear(128, 256)
            self.hidden_2 = nn.Linear(256, 128)
            self.hidden_state = torch.zeros(5, 1, 256)
        self.rnn = nn.GRU(256, 256, 5)
        self.action_head = nn.Linear(128, num_actions)
        self.value_head = nn.Linear(128, 1)
        self.action_scores = []
        self.saved_actions = []  # (log_prob, value) pairs collected during an episode
        self.rewards = []        # per-step rewards collected during an episode
    def forward(self, x):
        if self.num_cuda_devices == 1:
            x = self.input_layer(x.float()).cuda()
            x = torch.sigmoid(x).cuda()
            x = self.hidden_1(x)
            x = torch.tanh(x).cuda()
        elif self.num_cuda_devices > 1:
            x = self.input_layer(x.float()).cuda()
            x = torch.sigmoid(x)
            x = self.hidden_1(x).cuda()  # fails here with size mismatch, m1: [1 x 43], m2: [128 x 256] at /aten/src/THC/generic/THCTensorMathBlas.cu:266
            x = torch.tanh(x)
        else:
            x = self.input_layer(x.float())
            x = torch.sigmoid(x)
            x = self.hidden_1(x)
            x = torch.tanh(x)
        # run the GRU for a single time step; the hidden state carries over between calls
        x, self.hidden_state = self.rnn(x.view(1, -1, 256), self.hidden_state.data)
        x = self.hidden_2(x.squeeze())
        x = F.relu(x)
        self.action_scores = self.action_head(x)
        state_values = self.value_head(x)
        return F.softmax(self.action_scores, dim=-1), state_values

    def act(self, state):
        # sample an action and remember its log-probability together with the
        # critic's value estimate for the actor-critic update
        probs, state_values = self.forward(state)
        m = Categorical(probs)
        action_num = m.sample()
        self.saved_actions.append((m.log_prob(action_num), state_values))
        return action_num.item()
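
# Note on the size-mismatch error flagged in forward() above: when a single nn.Linear
# is wrapped in nn.DataParallel, the wrapper scatters its input along dim 0 across the
# GPUs. A 1-D activation of shape [128] is therefore split into per-GPU chunks of
# roughly 128 / 3 = 43 features, which matches the reported "m1: [1 x 43], m2: [128 x 256]"
# when one chunk reaches the [128 -> 256] weight. A minimal sketch of the usual
# workaround (an assumption, not part of the original gist) is to keep an explicit
# batch dimension on the activations so that only the batch axis is scattered:
#
#     state = state.float().unsqueeze(0)   # shape [1, 65]: dim 0 is a real batch axis
#     ...
#     x = self.hidden_2(x.squeeze(0))      # drop only the seq dim, keep the batch axis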
batch_length = 33
num_epochs = 2

# Data, Actions, Reward, Train and args are assumed to be defined elsewhere in the project
data = Data(args['data_dir'])
Actions = Actions(data)
Reward = Reward(data)
train = Train(data, Actions, Reward)

model = Policy(Actions.num_actions)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
    model.cuda()
model.to(device)

# keep a handle on the underlying Policy so its custom attributes (act, rewards,
# saved_actions) are reachable whether or not the model is wrapped in DataParallel
policy = model.module if isinstance(model, nn.DataParallel) else model

# not defined in the gist: an assumed Adam optimizer (the learning rate is a guess)
# and the usual small epsilon used when normalizing the returns
optimizer = optim.Adam(model.parameters(), lr=3e-2)
eps = np.finfo(np.float32).eps.item()
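
# The layers inside Policy are already wrapped in nn.DataParallel, and the whole model
# is wrapped again above. A sketch of the more common pattern (an assumption, not the
# author's stated intent) is to parallelize only at the top level and leave the
# submodules plain:
#
#     model = Policy(Actions.num_actions)
#     if torch.cuda.device_count() > 1:
#         model = nn.DataParallel(model)
#     model.to(device)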
def finish_episode():
    R = 0
    saved_actions = policy.saved_actions
    policy_losses = []
    value_losses = []
    rewards = []
    # compute the discounted return for every time step, iterating backwards
    for r in policy.rewards[::-1]:
        R = r + args.gamma * R
        rewards.insert(0, R)
    rewards = torch.tensor(rewards)
    # normalize the returns to zero mean and unit variance for a more stable update
    rewards = (rewards - rewards.mean()) / (rewards.std() + eps)
    for (log_prob, value), r in zip(saved_actions, rewards):
        reward = r - value.item()  # advantage: return minus the critic's estimate
        policy_losses.append(-log_prob * reward)
        # keep the regression target on the same device as the value estimate
        value_losses.append(F.smooth_l1_loss(value, torch.tensor([r]).to(value.device)))
    optimizer.zero_grad()
    loss = torch.stack(policy_losses).sum() + torch.stack(value_losses).sum()
    loss.backward()
    optimizer.step()
    del policy.rewards[:]
    del policy.saved_actions[:]
    return loss.item()
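
# Worked example of the return computation in finish_episode(), assuming
# args.gamma = 0.99 and per-step rewards [1, 0, 2]: iterating in reverse,
#     R = 2
#     R = 0 + 0.99 * 2    = 1.98
#     R = 1 + 0.99 * 1.98 = 2.9602
# so `rewards` becomes [2.9602, 1.98, 2.0] before normalization.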
running_reward = 1
for episode in range(0, 100000):
    state = train.reset(batch_length)
    reward = 0
    done = False
    msg = None
    time_step = 0
    while not done:
        # act through the unwrapped Policy handle so this works with or without DataParallel
        action_num_item = policy.act(state)
        state, reward, done, msg = train.step(action_num_item)
        policy.rewards.append(reward)
        time_step += 1
        if done:
            break
    # exponential moving average of the episode length
    running_reward = (running_reward * 0.99) + (time_step * 0.01)
    loss = finish_episode()