Skip to content

Instantly share code, notes, and snippets.

@dong-zeyu
Created August 24, 2023 15:06
Show Gist options
  • Save dong-zeyu/8a9dd4b2a862fab217cb0279437f047a to your computer and use it in GitHub Desktop.
Save dong-zeyu/8a9dd4b2a862fab217cb0279437f047a to your computer and use it in GitHub Desktop.
A multiprocessing helper that uses copy-on-write (COW) to boost speed and reduce memory usage
import multiprocessing as mp
import numpy as np
# Registry of in-flight map() calls: per-call integer key -> (input list, fn).
# map() adds an entry before creating its worker pool and removes it when the
# call finishes; init_worker() looks the entry up by key inside each worker.
global_data = {}
def array_to_numpy(output_type, output):
    """Expose ctypes buffer(s) as NumPy array(s) with the requested shape(s).

    Args:
        output_type: ``"list"`` when ``output`` is a sequence of
            ``(buffer, shape)`` pairs; any other value means ``output`` is a
            single ``(buffer, shape)`` pair.
        output: the pair(s) to convert.

    Returns:
        A list of reshaped arrays for ``"list"``, otherwise one reshaped array.
    """

    def _as_view(pair):
        buffer, shape = pair
        return np.ctypeslib.as_array(buffer).reshape(shape)

    if output_type != "list":
        return _as_view(output)
    # A comprehension is used on purpose: this module defines its own ``map``,
    # which shadows the builtin.
    return [_as_view(pair) for pair in output]
def init_worker(i, output_type_, output_):
    """Pool initializer: publish the per-call state as worker-process globals.

    Args:
        i: key of this call's ``(input, fn)`` entry in ``global_data``.
        output_type_: ``"list"`` or ``"array"``; forwarded to
            ``array_to_numpy`` and stored for ``process_data_wrapper``.
        output_: the ``(buffer, shape)`` pair(s) backing the results.

    The globals ``input``, ``fn``, ``output`` and ``output_type`` set here are
    the ones ``process_data_wrapper`` reads for every task.
    """
    global input, fn, output, output_type
    output_type = output_type_
    # The entry must already exist in this process's copy of global_data.
    shared_entry = global_data[i]
    input = shared_entry[0]
    fn = shared_entry[1]
    output = array_to_numpy(output_type_, output_)
def process_data_wrapper(i):
result = fn(input[i])
if output_type == "list":
for j in range(len(output)):
output[j][i] = result[j]
else:
output[i] = result
def initialize_output(size, output_t):
    """Allocate an unlocked shared buffer holding ``size`` results like ``output_t``.

    Args:
        size: number of results the buffer must hold.
        output_t: a sample result. Supported: ``numpy.ndarray`` of any number
            of dimensions, ``bool``, ``int``, ``float``, and NumPy bool,
            integer and floating scalars.

    Returns:
        ``(buffer, shape)`` where ``buffer`` is a ctypes array created with
        ``mp.Array(..., lock=False)`` and ``shape`` is ``(size, *sample_shape)``
        (``(size,)`` for scalars).

    Raises:
        TypeError: if ``output_t`` is none of the supported types.
        NotImplementedError: raised by ``np.ctypeslib.as_ctypes_type`` when a
            NumPy dtype has no ctypes equivalent.
    """
    if isinstance(output_t, np.ndarray):
        ctype = np.ctypeslib.as_ctypes_type(output_t.dtype)
        # ``output_t.size`` is the element count for any number of dimensions
        # (the previous shape[0] * shape[1] only worked for 2-D samples).
        return (
            mp.Array(ctype, size * output_t.size, lock=False),
            (size, *output_t.shape),
        )

    # bool must be tested before int: bool is a subclass of int, so testing int
    # first would make the bool case unreachable.
    if isinstance(output_t, (bool, np.bool_)):
        ctype = np.ctypeslib.as_ctypes_type(np.bool_)
    elif isinstance(output_t, int):
        ctype = "q"
    elif isinstance(output_t, float):
        ctype = "d"
    elif isinstance(output_t, (np.integer, np.floating)):
        ctype = np.ctypeslib.as_ctypes_type(output_t.dtype)
    else:
        raise TypeError(f"Cannot handle type {type(output_t)} for output")

    # lock=False for every type so the result is always a raw ctypes array,
    # which is what np.ctypeslib.as_array expects.
    return mp.Array(ctype, size, lock=False), (size,)
def map(fn, input):
    """Apply ``fn`` to every item of ``input`` using a pool of forked workers.

    ``fn`` is called once in the calling process on ``input[0]`` to discover
    the result layout; buffers are then preallocated with ``mp.Array`` and the
    workers write ``fn(input[k])`` into slot ``k`` (so ``fn(input[0])`` is
    evaluated twice in total).

    Args:
        fn: callable applied to each item. It must return a NumPy array or a
            primitive (int, float, bool), or a tuple/list of those, with the
            same layout for every item.
        input: any iterable; it is materialised into a list.

    Returns:
        One NumPy array whose first axis indexes the input items, or a list of
        such arrays when ``fn`` returns a tuple/list.

    Raises:
        IndexError: if ``input`` is empty (no sample to infer the layout from).
        ValueError: raised by ``mp.get_context`` when the ``fork`` start
            method is not available on this platform.
    """
    # The input has to be indexable by the workers.
    input = list(input)
    if not input:
        raise IndexError("map() requires at least one input item")

    # Probe the function output to learn what has to be allocated.
    output_t = fn(input[0])

    # Preallocate the output buffer(s) in the main process.
    if isinstance(output_t, (tuple, list)):
        output_type = "list"
        output = [initialize_output(len(input), item) for item in output_t]
    else:
        output_type = "array"
        output = initialize_output(len(input), output_t)

    # Register (input, fn) in a module global before creating the pool so that
    # forked workers inherit it (Linux shares the pages copy-on-write) instead
    # of receiving pickled copies. id(input) is unique while the list is
    # alive, and the registry entry keeps it alive; unlike a random key this
    # does not consume from the global NumPy random state.
    key = id(input)
    global_data[key] = (input, fn)
    try:
        # The registry is only visible to workers started with fork, so ask
        # for that start method explicitly rather than relying on the default.
        ctx = mp.get_context("fork")
        with ctx.Pool(
            mp.cpu_count(),
            initializer=init_worker,
            initargs=(key, output_type, output),
        ) as pool:
            pool.map(process_data_wrapper, range(len(input)))
        return array_to_numpy(output_type, output)
    finally:
        del global_data[key]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment