@zihangdai
Last active May 31, 2017 18:16
Multi-thread single-GPU AC learning
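"""Toy multi-thread, single-GPU actor-critic setup: a collector thread generates
(state, action, next_state, reward) transitions into a shared replay memory while
a trainer thread samples mini-batches from it and updates the shared model.
Access to the replay memory is guarded by a lock; both threads share one CUDA model."""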
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
from torch.autograd import Variable
import random
import threading
cuda = True
# Move tensors to GPU (when enabled) and wrap them for autograd.
variable = lambda x: Variable(x.cuda()) if cuda else Variable(x)
def collect(env, model, replay_mem, locks):
    """Collector thread: roll out (dummy) transitions into the shared replay memory."""
    net_lock, mem_lock = locks
    curr_state = torch.rand(1, 3, 32, 32)
    for i in range(1000):
        # Query the shared model for an action (random tensors stand in for a real env).
        action = model(variable(curr_state))
        reward = np.random.rand(1)
        next_state = torch.rand(1, 3, 32, 32)
        transition = (curr_state, action, next_state, reward)
        # Appending must not race with sampling in train().
        with mem_lock:
            replay_mem.append(transition)
        curr_state = next_state
        if i % 100 == 0:
            print('collect {}'.format(i))
def train(model, replay_mem, locks):
    """Trainer thread: sample mini-batches from the shared replay memory and update the model."""
    net_lock, mem_lock = locks
    optimizer = torch.optim.Adagrad(model.parameters(), 1e-5)
    for i in range(1000):
        # Busy-wait until the collector has produced enough transitions for a batch.
        while len(replay_mem) < 10:
            continue
        with mem_lock:
            samples = random.sample(replay_mem, 10)
        curr_state, action, next_state, reward = zip(*samples)
        curr_state = torch.cat(curr_state)
        next_state = torch.cat(next_state)
        # Forward pass; a placeholder objective stands in for a real actor-critic loss.
        out = model(variable(curr_state))
        loss = out.mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if i % 100 == 0:
            print('train {}'.format(i))
if __name__ == '__main__':
    # Small dummy network shared by both threads.
    model = nn.Sequential(
        nn.Conv2d(3, 100, 3, padding=1),
        nn.Conv2d(100, 3, 3, padding=1),
    )
    if cuda: model.cuda()

    replay_mem = []
    locks = [threading.Lock() for _ in range(2)]

    thread_collect = threading.Thread(target=collect, args=(None, model, replay_mem, locks))
    thread_train = threading.Thread(target=train, args=(model, replay_mem, locks))

    threads = [thread_collect, thread_train]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
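# Note: net_lock is created above but never acquired. If the collector's forward
# pass should not run concurrently with the trainer's in-place parameter updates,
# both threads could wrap their model access in it. A minimal sketch of that
# variant (the locking granularity here is an assumption, not part of the gist):
#
#     # in collect():
#     with net_lock:
#         action = model(variable(curr_state))
#
#     # in train():
#     with net_lock:
#         out = model(variable(curr_state))
#         loss = out.mean()
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()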