Skip to content

Instantly share code, notes, and snippets.

@gravicle
Last active August 26, 2021 19:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save gravicle/27f622a7d3c335657f6fd8925c37352c to your computer and use it in GitHub Desktop.
import functools
import multiprocessing as mp
from tqdm import tqdm
import math
from .device import *
# Force the 'spawn' start method for worker processes. Presumably chosen
# because the package also uses torch/CUDA (see device.py), which does not
# survive fork() — TODO confirm. set_start_method raises RuntimeError when a
# start method has already been fixed for this interpreter (e.g. by the host
# application or a previous import); that is expected and deliberately ignored.
try:
    mp.set_start_method('spawn')
except RuntimeError:
    pass

# Default worker count for parallel_map: every CPU reported by cpu_count().
N_CPU = cpu_count()
def parallel_map(task, iter, process_count: int = N_CPU, show_progress=False, debug=False):
    """Run ``task`` over every item of ``iter`` using a pool of worker processes.

    Args:
        task (function): Callable applied to each item. Outside debug mode it
            is sent to worker processes, so it must be picklable (e.g. a
            module-level function).
        iter (iterable): Items mapped to ``task``. Any iterable is accepted;
            it is materialised into a list first, so generators work too.
        process_count (int): Maximum number of worker processes. Capped at the
            number of items so no idle workers are spawned.
        show_progress (bool): Show a tqdm progress bar while the pool maps.
        debug (bool): Run sequentially in the calling process (always with a
            tqdm progress bar), which keeps tracebacks and breakpoints usable.

    Returns:
        list: ``task(item)`` for each item, in input order. Empty input
        yields an empty list.
    """
    # Materialise once so len() works for generators and a one-shot iterator
    # is not consumed before the actual mapping.
    items = list(iter)
    if not items:
        # Clamping the pool size to len(items) would give mp.Pool(0), which
        # raises ValueError; an empty map is simply empty.
        return []

    if debug:
        # Collect results so debug mode returns the same list as the pool path.
        return [task(item) for item in tqdm(items)]

    # Never spawn more workers than there are items.
    workers = min(process_count, len(items))
    with mp.Pool(workers) as pool:
        if show_progress:
            # imap yields ordered results lazily, letting tqdm tick per item.
            return list(tqdm(pool.imap(task, items), total=len(items)))
        return pool.map(task, items)
import multiprocessing as mp
import torch
def cpu_count():
    """Return the CPU count reported by ``multiprocessing.cpu_count()``.

    The stdlib call raises NotImplementedError if the count cannot be
    determined; that propagates unchanged.
    """
    n_cpus = mp.cpu_count()
    return n_cpus
def gpu_count():
    """Return the number of CUDA devices as reported by ``torch.cuda.device_count()``."""
    n_gpus = torch.cuda.device_count()
    return n_gpus
import numpy as np
from torch import rand
from luma_utils.concurrency import *
def task(numbers):
    """Return the sum of ``numbers`` as computed by ``np.sum``.

    Defined at module level so it can be pickled and sent to worker processes.
    """
    total = np.sum(numbers)
    return total
def run():
    """Smoke-test ``parallel_map`` against a plain sequential map.

    Builds 1000 shards of 10 consecutive integers, sums each shard both with
    ``parallel_map`` and with a list comprehension, and checks that the two
    result lists match.

    Raises:
        AssertionError: if the parallel and sequential results differ.
    """
    num_shards = [np.arange(i, i + 10) for i in range(1000)]
    result_parallel = parallel_map(task, num_shards)
    result_seq = [task(n) for n in num_shards]
    # Explicit check instead of a bare `assert`, so the comparison still runs
    # under `python -O` (which strips assert statements).
    if result_parallel != result_seq:
        first_bad = next(
            (i for i, (seq, par) in enumerate(zip(result_seq, result_parallel)) if seq != par),
            None,
        )
        raise AssertionError(
            f"parallel_map result differs from sequential map "
            f"(lengths {len(result_parallel)} vs {len(result_seq)}, "
            f"first mismatch at index {first_bad})"
        )


# The guard is required with the 'spawn' start method: each worker re-imports
# the main module, and must not re-run the test when it does.
if __name__ == '__main__':
    run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment