Skip to content

Instantly share code, notes, and snippets.

@dirkgr
Last active October 19, 2018 23:43
Show Gist options
  • Save dirkgr/26cbaa71f3591bf8135453f1ecac6ef2 to your computer and use it in GitHub Desktop.
A map function that runs every iteration in a separate process, in parallel
from typing import *
import multiprocessing as mp
import multiprocessing.connection
def map_per_process(fn, input_sequence: Iterable) -> Iterable:
    """Lazily map *fn* over *input_sequence*, running every invocation of *fn*
    in its own child process.

    Results are yielded as they become available, so they may arrive out of
    order relative to the input. At most ``mp.cpu_count()`` children run
    concurrently: below that limit we only collect results that are already
    done; at the limit we block until one finishes. If *fn* raises in a
    child, the exception object is shipped back and re-raised in the parent.

    NOTE(review): the child target is a closure over *fn*, so this relies on
    the 'fork' start method (the Linux default); it will not work under
    'spawn' — confirm the target platform.
    """
    # Both bookkeeping dicts are keyed by the receive pipe's fd number.
    pipeno_to_pipe: Dict[int, multiprocessing.connection.Connection] = {}
    pipeno_to_process: Dict[int, mp.Process] = {}

    def process_one_item(send_pipe: multiprocessing.connection.Connection, item):
        # Runs in the child: send (result, None) on success or (None, error)
        # on failure, then always close the child's end of the pipe.
        try:
            processed_item = fn(item)
        except Exception as e:
            send_pipe.send((None, e))
        else:
            send_pipe.send((processed_item, None))
        finally:
            # BUGFIX: close even when send() itself fails, so the parent's
            # wait() cannot hang on a pipe that will never deliver.
            send_pipe.close()

    def yield_from_pipes(pipes: List[multiprocessing.connection.Connection]):
        # Drain one result from each ready pipe, reap its process, and yield
        # the result (or re-raise the child's exception in the parent).
        for pipe in pipes:
            result, error = pipe.recv()
            pipeno = pipe.fileno()
            del pipeno_to_pipe[pipeno]
            pipe.close()

            process = pipeno_to_process[pipeno]
            process.join()
            del pipeno_to_process[pipeno]

            if error is None:
                yield result
            else:
                raise error

    try:
        for item in input_sequence:
            receive_pipe, send_pipe = mp.Pipe(duplex=False)
            process = mp.Process(target=process_one_item, args=(send_pipe, item))
            pipeno_to_pipe[receive_pipe.fileno()] = receive_pipe
            pipeno_to_process[receive_pipe.fileno()] = process
            process.start()
            # BUGFIX: close the parent's copy of the send end. The child owns
            # its own duplicate after fork; keeping ours open leaked one file
            # descriptor per item, exhausting fds on long input sequences.
            send_pipe.close()

            # If we have fewer processes going than we have CPUs, we just pick
            # up the values that are done. If we are at the process limit, we
            # wait until one of them is done.
            timeout = 0 if len(pipeno_to_process) < mp.cpu_count() else None
            ready_pipes = multiprocessing.connection.wait(pipeno_to_pipe.values(), timeout=timeout)
            yield from yield_from_pipes(ready_pipes)

        # Input exhausted: drain the remaining in-flight items.
        while len(pipeno_to_process) > 0:
            ready_pipes = multiprocessing.connection.wait(pipeno_to_pipe.values(), timeout=None)
            yield from yield_from_pipes(ready_pipes)
    finally:
        # Make sure no child outlives an aborted/abandoned iteration, and
        # release any receive pipes we never drained.
        for process in pipeno_to_process.values():
            if process.is_alive():
                process.terminate()
        for pipe in pipeno_to_pipe.values():
            pipe.close()
def map_in_chunks(fn, chunk_size: int, input_sequence: Iterable) -> Iterable:
    """Like map_per_process, but ships items to the worker processes in
    chunks of *chunk_size*, amortizing the per-process overhead over many
    items. Yields the mapped items one at a time."""
    def process_chunk(chunk: List) -> List:
        # Runs in the child: apply fn to every element of one chunk.
        return [fn(element) for element in chunk]

    for chunk_result in map_per_process(process_chunk, slices(chunk_size, input_sequence)):
        yield from chunk_result
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment