Skip to content

Instantly share code, notes, and snippets.

@AlexPiche
Last active August 17, 2017 14:01
Show Gist options
  • Save AlexPiche/5f3dcd2d7129a4454854773e29da541c to your computer and use it in GitHub Desktop.
Save AlexPiche/5f3dcd2d7129a4454854773e29da541c to your computer and use it in GitHub Desktop.
import torch.multiprocessing as mp
from torch import nn
import torch
from torch.autograd import Variable
# Total number of train() runs: rank 0 in the parent process plus
# (num_processes - 1) worker processes started in the __main__ block.
num_processes: int = 3
import gym
import roboschool
def train(args, rank):
    """Run one forward step of a Linear -> LSTMCell model on a gym env.

    Builds the 'MountainCarContinuous-v0' environment and a small model,
    feeds the environment's initial observation through it once, and prints
    "all good" if that succeeds.

    Args:
        args: Not used by this function; accepted so the signature matches
            the ``mp.Process(target=train, args=(args, rank))`` call sites.
        rank: Worker index; only printed.
    """
    print(rank)
    env = gym.make('MountainCarContinuous-v0')
    try:
        # nn.Linear(2, 200) requires each observation to have exactly 2
        # features.
        linear = nn.Linear(2, 200)
        rnn = nn.LSTMCell(200, 128)
        # NOTE(review): assumes env.reset() returns a numpy array (older gym
        # API) — newer gym versions return (obs, info); confirm the version.
        observation = env.reset()
        observation = torch.from_numpy(observation)
        # Cast to float32 and add a leading batch dimension of size 1.
        observation = Variable(observation.float().unsqueeze(0))
        # Zero initial cell / hidden state for a batch of one.
        cx = Variable(torch.zeros(1, 128))
        hx = Variable(torch.zeros(1, 128))
        x = linear(observation)
        x = x.view(-1, 200)
        hx, cx = rnn(x, (hx, cx))
        print("all good")
    finally:
        # Release the environment even if the forward pass raises.
        env.close()
processes = []
args = []
if __name__ == '__main__':
    # train(args, 0) below runs torch ops in this parent process *before* the
    # workers are started. With the 'fork' start method (Linux default),
    # children inherit the parent's already-initialized native thread state,
    # which can leave them deadlocked. A 'spawn' context starts each worker in
    # a fresh interpreter instead; train is a module-level function, so it can
    # be pickled and re-imported in the child.
    ctx = mp.get_context('spawn')
    train(args, 0)
    for rank in range(1, num_processes):
        p = ctx.Process(target=train, args=(args, rank))
        p.start()
        processes.append(p)
    # Wait for every worker to finish before exiting.
    for p in processes:
        p.join()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment