Each worker is roughly of the following form:
# socket is some form of communication channel between processes
def train(id, device, batch_queue, socket):
    """Worker loop: build the model, attach to the shared (memory-mapped)
    parameters, then run SGD over incoming batches, periodically exchanging
    with the shared copy.

    id          -- 1-based worker index; worker 1 creates the shared map
    device      -- accelerator this worker runs on
    batch_queue -- queue of training batches fed by the main thread
    socket      -- channel used to hand off the shared-memory map's name
    """
    construct_graph_and_compile()
    if id == 1:
        # First worker: memory-map the parameters and publish the map's
        # *name* so other workers can attach via memory_map_from(name).
        # (Was socket.send(m_params): receivers expect a name, not the
        # object itself.)
        m_params = memory_map(params)
        # NOTE(review): assumes the map object exposes a .name handle —
        # confirm against the memory_map API.
        socket.send(m_params.name)
    else:
        # Attach to the memory map already created by worker 1.
        m_params_name = socket.get()
        m_params = memory_map_from(m_params_name)
    i = 0
    while True:
        i += 1
        sgd(batch_queue.get())
        # Sync with the shared copy every 10 steps: bounds parameter
        # staleness while keeping synchronization overhead low.
        if i % 10 == 0:
            exchange_params(params, m_params)
The main thread is of the form:
def main():
    """Launch the worker pool, broker the shared parameter map between
    workers, then feed training batches forever."""
    pool = create_pool()
    # Add one worker first to warm the cache and create the memory-mapped
    # parameters.
    pool.add(workers[0], 'gpu0', batch_queue, socket)
    # Capture the shared-memory handle from the first worker.
    # (Was a bare socket.get(): the result was discarded, leaving m_params
    # undefined when pushed below.)
    m_params = socket.get()
    # Start all the other workers.
    pool.add(workers[1:], 'gpu1-7', batch_queue, socket)
    # Forward the memory-mapped parameter handle to the other workers.
    socket.push(m_params)
    # Start feeding batches to the workers.
    while True:
        # Fixed typo: was bath_queue.push(...).
        batch_queue.push(get_batch())