Easy parallel processing in Python with a progress bar
"""
Simple mofo'n parallelism with progress bar. Born of frustration with p_tqdm.
"""
def pmap_do(inp):
func,obj,args,kwargs = inp
return func(obj,*args,**kwargs)
def pmap_iter(func, objs, args=[], kwargs={}, num_proc=4, use_threads=False, progress=True, desc=None, **y):
    """
    Yields results of func(obj, *args, **kwargs) for each obj in objs.
    Uses multiprocessing.Pool(num_proc) for parallelism;
    if use_threads, uses a ThreadPool instead.
    Results are yielded in input order (imap preserves order).
    Note: func and objs must be picklable for process pools.
    """
    from tqdm import tqdm
    # build the progress-bar description
    if not desc: desc = f'Mapping {func.__name__}()'
    desc = f'{desc} [x{num_proc}]'
    if num_proc > 1 and len(objs) > 1:
        # pack each object with the function and shared args so workers can unpack them
        objects = [(func, obj, args, kwargs) for obj in objs]
        # create the pool
        import multiprocessing as mp
        from multiprocessing.pool import ThreadPool
        pool = ThreadPool(num_proc) if use_threads else mp.Pool(num_proc)
        try:
            # imap yields results lazily as workers finish, in input order
            iterr = pool.imap(pmap_do, objects)
            for res in (tqdm(iterr, total=len(objs), desc=desc) if progress else iterr):
                yield res
        finally:
            # always close the pool, even if the consumer stops early
            pool.close()
            pool.join()
    else:
        # serial fallback: no pool overhead for a single process or object
        for obj in (tqdm(objs, desc=desc) if progress else objs):
            yield func(obj, *args, **kwargs)
def pmap(*x, **y):
    """
    Non-iterator version of pmap_iter: returns results as a list.
    """
    return list(pmap_iter(*x, **y))
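
# A minimal usage sketch (the `slow_square` function is an illustrative
# assumption, not part of this gist). Kept commented out so the module stays
# import-safe for worker processes:
#
#     import time
#
#     def slow_square(x, power=2):
#         time.sleep(0.1)
#         return x ** power
#
#     if __name__ == '__main__':
#         results = pmap(slow_square, range(20), kwargs={'power': 3}, num_proc=4)
#         # shows a tqdm bar labeled "Mapping slow_square() [x4]"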
def do_pmap_group(obj):
    import pandas as pd  # imported here so worker processes have it
    # unpack
    func, group_df, group_key, group_name = obj
    # load from cache: a str means a path to a pickled DataFrame
    if isinstance(group_df, str):
        group_df = pd.read_pickle(group_df)
    # run func
    outdf = func(group_df)
    # annotate output rows with the group name(s) on the way out
    if not isinstance(group_name, (list, tuple)):
        group_name = [group_name]
    for x, y in zip(group_key, group_name):
        outdf[x] = y
    return outdf
def pmap_groups(func, df_grouped, use_cache=True, **attrs):
    import os, tempfile
    import pandas as pd
    from tqdm import tqdm
    # get the groupby column name(s)
    group_key = df_grouped.grouper.names
    if not use_cache:
        # pass each group DataFrame to the workers directly
        objs = [
            (func, group_df, group_key, group_name)
            for group_name, group_df in df_grouped
        ]
    else:
        # pickle each group to a temp file and pass the path instead,
        # so large DataFrames aren't serialized through the pool
        # (the temp dir is not cleaned up here)
        objs = []
        tmpdir = tempfile.mkdtemp()
        for i, (group_name, group_df) in enumerate(tqdm(list(df_grouped), desc='Preparing input')):
            tmp_path = os.path.join(tmpdir, f'{i}.pkl')
            group_df.to_pickle(tmp_path)
            objs += [(func, tmp_path, group_key, group_name)]
    # default progress-bar description
    if not attrs.get('desc'):
        attrs['desc'] = f'Mapping {func.__name__}'
    # map over the groups in parallel, then reassemble indexed by group
    return pd.concat(
        pmap(do_pmap_group, objs, **attrs)
    ).set_index(group_key)
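
# A usage sketch for pmap_groups (the `summarize` function and column names
# are illustrative assumptions, not part of the gist):
#
#     import pandas as pd
#
#     def summarize(group_df):
#         return group_df[['val']].sum().to_frame().T
#
#     if __name__ == '__main__':
#         df = pd.DataFrame({'key': ['a', 'a', 'b'], 'val': [1, 2, 3]})
#         out = pmap_groups(summarize, df.groupby('key'), num_proc=2)
#         # `out` has one summary row per group, indexed by 'key'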
def pmap_df(df, func, num_proc=1):
    # split the DataFrame into num_proc chunks, map func over the chunks
    # in parallel, then reassemble
    import numpy as np
    import pandas as pd
    df_split = np.array_split(df, num_proc)
    df = pd.concat(pmap(func, df_split, num_proc=num_proc))
    return df
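
# A usage sketch for pmap_df (hypothetical `double` function; note that func
# runs once per chunk, so it must be safe to apply piecewise):
#
#     import pandas as pd
#
#     def double(chunk):
#         chunk['val'] = chunk['val'] * 2
#         return chunk
#
#     if __name__ == '__main__':
#         big_df = pd.DataFrame({'val': range(1_000_000)})
#         big_df = pmap_df(big_df, double, num_proc=8)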