@bartvm
Last active December 3, 2015 20:05
Pseudo-code multi-GPU

Each worker is roughly of the following form:

# socket is some form of communication
def train(id, device, batch_queue, socket):
    construct_graph_and_compile()
    if id == 1:
        # This is the first process, so memory-map the parameters
        m_params = memory_map(params)
        # Send them to the main thread
        socket.send(m_params)
    else:
        # The parameters should already be memory-mapped
        m_params_name = socket.get()
        m_params = memory_map_from(m_params_name)
    i = 0
    while True:
        i += 1
        sgd(batch_queue.get())
        if i % 10 == 0:
            exchange_params(params, m_params)
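One concrete way to realize memory_map, memory_map_from, and exchange_params is with Python 3.8+'s multiprocessing.shared_memory and NumPy. This is only a sketch under those assumptions; the function names mirror the pseudocode above, and the averaging rule in exchange_params is one simple choice among several:

```python
import numpy as np
from multiprocessing import shared_memory

def create_shared_params(params):
    # First worker: back the parameters with a named shared-memory block
    # (the analogue of memory_map above); the name is what gets sent
    # over the socket.
    shm = shared_memory.SharedMemory(create=True, size=params.nbytes)
    m_params = np.ndarray(params.shape, dtype=params.dtype, buffer=shm.buf)
    m_params[:] = params
    return shm, m_params

def attach_shared_params(name, shape, dtype):
    # Other workers: attach to the existing block by name
    # (the analogue of memory_map_from above).
    shm = shared_memory.SharedMemory(name=name)
    return shm, np.ndarray(shape, dtype=dtype, buffer=shm.buf)

def exchange_params(params, m_params):
    # Average local and shared parameters and write the result back to
    # both; done in place so all attached workers see the update.
    mean = (params + m_params) / 2
    params[:] = mean
    m_params[:] = mean
```

Note that this averaging is not atomic; with several workers exchanging concurrently a lock around exchange_params (or a tolerance for slightly stale reads) would be needed.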

The main thread is of the form:

def main():
    pool = create_pool()
    # Add one worker to the pool to warm the cache and create the
    # memory-mapped parameters
    pool.add(workers[0], 'gpu0', batch_queue, socket)
    # Get the shared-memory parameters from the first worker
    m_params = socket.get()
    # Start all the other workers
    pool.add(workers[1:], 'gpu1-7', batch_queue, socket)
    # Send the memory-mapped parameters to the other workers
    socket.push(m_params)
    # Start feeding batches to the workers
    while True:
        batch_queue.push(get_batch())
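The wiring between the main process and the workers can be sketched with standard multiprocessing primitives. This is a toy stand-in, not the real training loop: a Queue carries batches (plain integers here), a Pipe plays the role of socket, and train only simulates a worker that publishes the shared-parameter name and consumes a few batches:

```python
import multiprocessing as mp

def train(worker_id, device, batch_queue, conn):
    # Stand-in for the real worker: the first worker would create the
    # shared-memory parameters and publish their name.
    if worker_id == 0:
        conn.send('shared-params-name')
    total = 0
    for _ in range(3):  # a real worker loops forever
        total += batch_queue.get()
    conn.send((device, total))

def main():
    batch_queue = mp.Queue()
    parent, child = mp.Pipe()
    w = mp.Process(target=train, args=(0, 'gpu0', batch_queue, child))
    w.start()
    # Block until the first worker has published the shared-memory name;
    # this is where the other workers would be started and sent the name.
    name = parent.recv()
    for batch in (1, 2, 3):  # feed batches; real code calls get_batch()
        batch_queue.put(batch)
    device, total = parent.recv()
    w.join()
    return name, device, total
```

The blocking parent.recv() is what enforces the "warm the cache first" ordering above: no batches flow until the first worker has set up the shared parameters.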