# Run func(args) concurrently and aggregate the results
# author: @hehao98 <heh@pku.edu.cn>, @12f23eddde <12f23eddde@gmail.com>
from typing import Callable, Iterable, Optional, Any, Tuple, TypeVar, Union
import time
import traceback

from multiprocess import Pool, cpu_count  # dill-based fork of multiprocessing
from tqdm.auto import tqdm
import pandas as pd

T = TypeVar("T")
SupportedFuncType = Callable[[Any], T]
AggFuncType = Callable[[Optional[T], Optional[T]], T]


def _get_default_n_workers() -> int:
    # default worker count: ~2/3 of all cores, but at least 4
    # (e.g., max(4, int(8 / 3 * 2)) == 5 workers on an 8-core machine)
    return max(4, int(cpu_count() / 3 * 2))
def agg_sum(r: Optional[Union[int, float]], s: Optional[Union[int, float]]) -> Union[int, float]:
    # parallel() first calls agg_func(None, None) to initialize the accumulator
    if r is None:
        return 0
    return r + s


def agg_append_df(r: Optional[pd.DataFrame], s: Optional[pd.DataFrame]) -> pd.DataFrame:
    # returns an empty DataFrame on the initialization call, then concatenates
    if r is None:
        return pd.DataFrame()
    return pd.concat([r, s])
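
# A hypothetical further aggregator (not in the original gist), following the
# same contract as agg_sum/agg_append_df: parallel() first calls
# agg_func(None, None) to build an empty accumulator, then folds each worker
# result in with agg_func(accumulator, result). Here func is assumed to
# return a list on every call, and the partial lists are concatenated.
def agg_extend_list(r: Optional[list], s: Optional[list]) -> list:
    if r is None:  # initialization call
        return []
    return r + s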

def parallel(
    func: SupportedFuncType,
    args: Iterable[Any],
    agg_func: Optional[AggFuncType] = None,
    n_workers: Optional[int] = None,
    total: Optional[int] = None,
    progress_bar=tqdm,
):
    """
    Wraps multiprocess.Pool (the dill-based fork of multiprocessing, so
    lambdas and closures can also be passed as func).
    :param func: function to parallelize (must accept exactly one argument)
    :param args: iterable of arguments, one per call to func
    :param agg_func: function to aggregate the results; if None, parallel() returns None
    :param n_workers: number of worker processes (default: max(4, 2/3 of the core count))
    :param total: number of iterations (default: len(args))
    :param progress_bar: tqdm-compatible progress bar class (default: tqdm.auto.tqdm)
    ---
    Example:
    ```python
    def func_to_parallel(a: int = 1, b: int = 2) -> pd.DataFrame:
        time.sleep(a + b)
        return pd.DataFrame([{"sum": a + b}])

    def wrapper(args: Tuple[int, int]) -> pd.DataFrame:
        return func_to_parallel(*args)

    res = parallel(wrapper, [(0, 1), (1, 2), (2, 3)], agg_func=agg_append_df)
    ```
    """
    if not n_workers:
        n_workers = _get_default_n_workers()
    if not total:
        total = len(args)  # requires a sized iterable; pass total= for generators
    pool = Pool(n_workers)
    try:
        start = time.time()
        # consume results with imap_unordered as they complete (order is not preserved)
        with progress_bar(total=total) as t:
            r = None
            if agg_func is not None:
                r = agg_func(None, None)  # initialize the accumulator
            for i in pool.imap_unordered(func, args):
                if agg_func is not None:
                    r = agg_func(r, i)
                t.set_postfix({"func": func.__name__, "time": "%.1fs" % (time.time() - start)})
                t.update()
            return r
    except Exception as e:
        print(e, flush=True)
        traceback.print_exc()
    finally:
        pool.close()  # stop accepting new jobs
        pool.join()   # wait for the worker processes to exit

if __name__ == "__main__":
    def func_to_parallel(a: int = 1, b: int = 2) -> pd.DataFrame:
        time.sleep(a + b)
        return pd.DataFrame([{"sum": a + b}])

    def wrapper(args: Tuple[int, int]) -> pd.DataFrame:
        return func_to_parallel(*args)

    res = parallel(wrapper, [(0, 1), (1, 2), (2, 3)], agg_func=agg_append_df)
    print(res.head())
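
    # A minimal additional sketch (not in the original gist): the same pairs
    # aggregated with agg_sum. sum_wrapper is a hypothetical helper; each call
    # returns a number, and since addition is order-independent the result is
    # stable even though imap_unordered yields results in completion order.
    def sum_wrapper(args: Tuple[int, int]) -> int:
        a, b = args
        return a + b

    total_sum = parallel(sum_wrapper, [(0, 1), (1, 2), (2, 3)], agg_func=agg_sum)
    print(total_sum)  # expected: 9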