@seahrh · Last active August 6, 2021
multiprocessing: update shared dictionary by multiple processes
import json
import multiprocessing
from typing import Dict, List

from tqdm import tqdm


def dowork(qids: List[str], shared_dict, lock) -> None:
    # Hold the lock for the whole batch so this worker's updates to the
    # shared dict are not interleaved with other workers'.
    with lock:
        for qid in qids:
            # update shared_dict
            pass


def doparallel(
    n_proc: int,
    batch_size: int,
    qids: List[str],
    directory: str,
    timeout: int = 360,
) -> Dict[str, str]:
    res: Dict[str, str] = {}
    with multiprocessing.Manager() as m:
        sd = m.dict()  # type: ignore
        lock = m.Lock()  # type: ignore
        with multiprocessing.Pool(processes=n_proc) as pool:
            batches = []
            for i in tqdm(range(0, len(qids), batch_size)):
                _qids = qids[i : i + batch_size]
                batches.append((_qids, sd, lock))
                end = i + batch_size  # index just past this batch
                # Dispatch once we have one batch per process, or at the end.
                if len(batches) == n_proc or end >= len(qids):
                    rs = []
                    for arg_tuple in batches:
                        rs.append(pool.apply_async(dowork, arg_tuple))
                    for r in rs:
                        # wait() does not raise on timeout; the length check
                        # below catches batches that did not finish in time.
                        r.wait(timeout)
                    res = dict(sd)
                    e_len = min(end, len(qids))
                    if len(res) != e_len:
                        raise ValueError(
                            f"shared_dict did not have expected length. "
                            f"len(res)={len(res)}, expected={e_len}"
                        )
                    # Checkpoint intermediate results to disk.
                    with open(f"{directory}/pred.json", "w") as f:
                        json.dump(res, f)
                    batches = []
            # Copy out of the proxy before the manager shuts down.
            res = dict(sd)
    print(f"len(res)={len(res)}, len(sd)={len(sd)}")
    return res
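
A minimal usage sketch: the gist leaves the body of dowork as a placeholder, so this fills it in with a toy worker for illustration. The "prediction-for-..." values, the dummy qids, and the /tmp output directory are assumptions, not part of the original gist.

# Hypothetical worker body replacing the stub above: writes one
# placeholder result per qid into the manager-backed shared dict.
def dowork(qids: List[str], shared_dict, lock) -> None:
    with lock:
        for qid in qids:
            shared_dict[qid] = f"prediction-for-{qid}"  # placeholder result


if __name__ == "__main__":
    qids = [f"q{i}" for i in range(100)]  # dummy ids for illustration
    res = doparallel(n_proc=4, batch_size=10, qids=qids, directory="/tmp")
    print(f"collected {len(res)} predictions")

Note that because the worker holds the lock for its entire batch, dict updates are effectively serialized; this is safe but only pays off when the real per-qid work (done before acquiring the lock) dominates the runtime.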