@soumith
Forked from anonymous/parameter_server.py
Created January 1, 2017 23:54
import sys

import torch.multiprocessing as mp
from torch.multiprocessing import Semaphore

if sys.version_info[0] == 3:
    Barrier = mp.Barrier
else:  # Python 2 has no mp.Barrier, so use a semaphore-based fallback
    # from http://stackoverflow.com/a/26703365/117844
    class Barrier:
        def __init__(self, n):
            self.n = n
            self.count = 0
            self.mutex = Semaphore(1)
            self.barrier = Semaphore(0)

        def wait(self):
            self.mutex.acquire()
            self.count = self.count + 1
            self.mutex.release()
            if self.count == self.n:
                self.barrier.release()
            self.barrier.acquire()
            self.barrier.release()

class ParameterServer(object):
    def __init__(self, n_processes):
        self.queue = mp.Queue()
        self.n_processes = n_processes
        self.barrier = Barrier(n_processes)

    # Only the queue, the barrier, and the process count need to cross the
    # process boundary, so pickling is restricted to exactly those fields.
    def __getstate__(self):
        return (self.queue, self.barrier, self.n_processes)

    def __setstate__(self, state):
        self.queue, self.barrier, self.n_processes = state

    def sync_model(self, rank, model=None):
        if rank == 0:
            # rank 0 publishes its model once for every other process
            assert model is not None
            for i in range(self.n_processes - 1):
                self.queue.put(model)
        else:
            model = self.queue.get()
            # clone the gradients to break the sharing (the parameters
            # themselves stay in shared memory); skip any parameter whose
            # .grad did not come through the transfer
            for param in model.parameters():
                if param.grad is not None:
                    param._grad = param.grad.clone()
        self.barrier.wait()
        return model
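
For reference, here is a minimal usage sketch of how the class above is meant to be driven; it is not part of the original gist. One ParameterServer is created in the parent process and handed to every worker, and each worker calls sync_model with its rank: rank 0 publishes the model, the other ranks receive copies whose parameters live in shared memory. The toy nn.Linear model, the worker body, and the choice of 4 processes are illustrative assumptions, and the snippet assumes it sits in the same file as the class above.

import torch
import torch.nn as nn
import torch.multiprocessing as mp


def worker(rank, server):
    # rank 0 builds and publishes the model; the other ranks receive it
    model = nn.Linear(10, 1) if rank == 0 else None
    model = server.sync_model(rank, model)

    # per-process training on the shared-parameter model would go here,
    # e.g. a forward/backward pass followed by an optimizer step
    loss = model(torch.randn(8, 10)).sum()
    loss.backward()


if __name__ == '__main__':
    n_processes = 4
    server = ParameterServer(n_processes)

    workers = [mp.Process(target=worker, args=(rank, server))
               for rank in range(n_processes)]
    for p in workers:
        p.start()
    for p in workers:
        p.join()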