Last active
August 31, 2020 02:36
-
-
Save austospumanto/6205276f84cd4dde38f3ce17dddccdb3 to your computer and use it in GitHub Desktop.
processit
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
System/Runtime Requirements: | |
>=Python3.7 | |
Linux / Mac | |
>=2 CPU Cores | |
Must pip install to run `processit`: | |
pickle5, tqdm | |
Must pip install to run the tests and some functions: | |
numpy, pandas, seaborn | |
""" | |
import contextlib | |
import dataclasses as dc | |
import math | |
import multiprocessing as mp | |
import os | |
import platform | |
import struct | |
import time | |
import traceback | |
import warnings | |
from functools import wraps | |
from itertools import chain | |
from multiprocessing import cpu_count | |
from multiprocessing.connection import Connection | |
from typing import Optional, List, Any, Tuple, Callable, Dict, Collection, Iterable, Union | |
import pickle5 | |
from tqdm import tqdm | |
# Names of pandas types kept as plain strings; pandas itself is not imported
# at module level in this file.
PdDataFrameT = "pandas.DataFrame"
PdIndexT = "pandas.Index"
# Positional / keyword arguments of a todo's target callable
# (used in the return annotation of `_WorkerGlobals.read_assigned_work`).
ArgsT = Tuple[Any, ...]
KwargsT = Dict[str, Any]
########### | |
# GLOBALS # | |
########### | |
# Index of the current worker; None until `_worker_main` assigns it.
_g_worker_index = None
# Verbosity level: `_maybe_print` prints at >= 1, `_maybe_procprint` at >= 2.
_g_verbose = 0
# Lock held by `_procprint` while it prints.
_g_print_lock = mp.Lock()
########## | |
# PUBLIC # | |
########## | |
def timeit():
    """Build a decorator that reports how long each call of a function took.

    Must be applied with parentheses (``@timeit()``). The report is emitted
    through ``_maybe_print`` after the wrapped call returns; a call that
    raises is not reported.

    Returns:
        A decorator whose wrapper forwards all arguments to the wrapped
        function and returns its result unchanged.
    """

    def outer(func):
        @wraps(func)
        def inner(*args, **kwargs):
            # perf_counter() is monotonic, so the measured interval cannot be
            # distorted by wall-clock adjustments the way time.time() can.
            started = time.perf_counter()
            res = func(*args, **kwargs)
            interval = time.perf_counter() - started
            _maybe_print("Time for '%s': %0.3f seconds" % (func.__qualname__, interval))
            return res

        return inner

    return outer
def set_verbose(v: Union[bool, int]) -> None:
    """Set the module-wide verbosity level.

    Args:
        v: New level, coerced with ``int`` (``True`` becomes 1, ``False`` 0).
            ``_maybe_print`` prints at levels >= 1 and ``_maybe_procprint``
            at levels >= 2.
    """
    global _g_verbose
    level = int(v)
    _g_verbose = level
@timeit()
def processit(
    todos: List[dict],
    desc: Optional[str] = None,
    max_nprocs: int = cpu_count(),
    common_kwargs=None,
) -> List[Any]:
    """Run every todo in worker processes and return the results in todo order.

    Each todo is a dict with a required ``"target"`` callable and optional
    ``"args"`` (positional arguments) and ``"kwargs"`` (keyword arguments).

    Args:
        todos: Non-empty sequence of todo dicts.
        desc: Progress-bar label; no progress bar is created when falsy.
        max_nprocs: Upper bound on the number of workers; a falsy value means
            ``cpu_count()``.
        common_kwargs: Keyword arguments merged into every todo's kwargs. On
            a key clash the common value wins.

    Returns:
        One result per todo, at the same position as the todo.

    Raises:
        AssertionError: If ``todos`` is empty or its first item is not a dict.
    """
    n_todos = len(todos)
    max_nprocs = max_nprocs or cpu_count()
    assert n_todos > 0, n_todos
    first_todo = next(iter(todos))
    assert isinstance(first_todo, dict), type(first_todo)

    if platform.system() not in ("Linux", "Darwin"):
        # Serial fallback for other platforms. Going through `_execute_todo`
        # tolerates todos without "args"/"kwargs" keys and merges
        # `common_kwargs` the same way the worker path does (common wins).
        shared_kwargs = common_kwargs or {}
        return [_execute_todo(todo, shared_kwargs) for todo in todos]

    import gc

    # NOTE(review): presumably collects garbage so workers start from a
    # smaller heap -- confirm intent.
    gc.collect()

    # Workers read the todos from `_WorkerGlobals` instead of receiving them
    # as process arguments.
    with _WorkerGlobals.context(todos=todos, common_kwargs=common_kwargs):
        n_workers = min(n_todos, max_nprocs)
        m_lock = mp.Lock()
        # Shared counter holding the index of the next unclaimed todo.
        m_todo_index = mp.Value("i", 0)
        handles: List[_WorkerHandle] = _start_workers(
            n=n_workers, m_lock=m_lock, m_todo_index=m_todo_index
        )
        return _gather_results(handles=handles, desc=desc, n_todos=n_todos)
@timeit()
def processit_chunky(
    todos: List[dict],
    desc: Optional[str] = None,
    max_nprocs: int = cpu_count(),
    common_kwargs=None,
) -> List[Any]:
    """Like ``processit``, but hand each worker a chunk of todos to run serially.

    The todos are grouped into at most ``max_nprocs`` chunks; each chunk
    becomes a single todo for ``processit``, and the per-chunk result lists
    are flattened back into one list.

    Args:
        todos: Non-empty sequence of todo dicts (see ``processit``).
        desc: Progress-bar label, forwarded to ``processit``; progress is
            therefore counted per chunk, not per todo.
        max_nprocs: Number of chunks and upper bound on workers; a falsy
            value means ``cpu_count()``.
        common_kwargs: Keyword arguments merged into every todo's kwargs.

    Returns:
        One result per original todo, in the original order.

    Raises:
        AssertionError: If ``todos`` is empty or its first item is not a dict.
    """
    n_todos = len(todos)
    assert n_todos > 0, n_todos
    first_todo = next(iter(todos))
    assert isinstance(first_todo, dict), type(first_todo)

    # `processit` treats a falsy max_nprocs as cpu_count(); resolve it the
    # same way here so the chunk-size computation never divides by zero.
    max_nprocs = max_nprocs or cpu_count()

    chunked_todos = _chunkify_todos(todos, nchunks=max_nprocs)
    per_chunk_results = processit(
        todos=chunked_todos, desc=desc, max_nprocs=max_nprocs, common_kwargs=common_kwargs,
    )
    return list(chain.from_iterable(per_chunk_results))
def processit_map(
    target: Callable,
    map_onto: List,
    desc: Optional[str] = None,
    max_nprocs=cpu_count(),
    common_kwargs=None,
    chunky=False,
) -> List[Any]:
    """Call ``target`` once per element of ``map_onto`` in worker processes.

    Each element is passed to ``target`` as its single positional argument.

    Args:
        target: Callable applied to every element.
        map_onto: Elements to process.
        desc: Progress-bar label, forwarded unchanged.
        max_nprocs: Forwarded unchanged to the chosen runner.
        common_kwargs: Keyword arguments forwarded unchanged.
        chunky: Use ``processit_chunky`` when truthy, ``processit`` otherwise.

    Returns:
        Whatever the chosen runner returns for the generated todos.
    """
    todos = [{"target": target, "args": (item,)} for item in map_onto]
    runner = processit_chunky if chunky else processit
    return runner(
        todos=todos, desc=desc, max_nprocs=max_nprocs, common_kwargs=common_kwargs
    )
def processit_map_identity(
    targets: List[Callable], desc: Optional[str] = None, common_kwargs=None
) -> List[Any]:
    """Run each callable in ``targets`` through ``processit``.

    Every callable becomes a todo with no "args" or "kwargs" entry, so the
    only arguments it can receive are those in ``common_kwargs``.

    Args:
        targets: Callables to run.
        desc: Progress-bar label, forwarded to ``processit``.
        common_kwargs: Keyword arguments forwarded to ``processit``.

    Returns:
        The list returned by ``processit`` for the generated todos.
    """
    todos = [{"target": fn} for fn in targets]
    return processit(todos=todos, desc=desc, common_kwargs=common_kwargs)
def processit_dict(
    todos_dict: Dict[str, dict],
    desc: Optional[str] = None,
    max_nprocs: int = cpu_count(),
    common_kwargs=None,
) -> Dict[str, Any]:
    """Run a mapping of named todos and return a mapping of named results.

    Args:
        todos_dict: Maps each key to a todo dict (see ``processit``). An
            empty mapping fails while unpacking the keys and todos.
        desc: Progress-bar label, forwarded to ``processit``.
        max_nprocs: Forwarded to ``processit``.
        common_kwargs: Keyword arguments forwarded to ``processit``.

    Returns:
        A dict pairing every key of ``todos_dict`` with its todo's result.
    """
    # Split the mapping into parallel tuples so results can be re-keyed below.
    keys, todos = zip(*todos_dict.items())
    results = processit(
        todos, desc=desc, max_nprocs=max_nprocs, common_kwargs=common_kwargs
    )
    return {key: result for key, result in zip(keys, results)}
########### | |
# PRIVATE # | |
########### | |
class _WorkerGlobals:
    """Class-level holder for the batch of todos currently being processed.

    ``processit`` installs the todos through ``context`` before it starts the
    workers, and the worker code reads them back through the class
    attributes and ``read_assigned_work``.
    """

    # Sequence of todo dicts for the active batch; None outside `context`.
    todos = None
    # Keyword arguments applied to every todo of the active batch.
    common_kwargs = None

    @classmethod
    @contextlib.contextmanager
    def context(cls, todos, common_kwargs):
        """Install ``todos`` / ``common_kwargs`` for the duration of a ``with`` block.

        Args:
            todos: Sequence of todo dicts to expose to workers.
            common_kwargs: Shared keyword arguments; a falsy value is stored
                as an empty dict.

        The previous values are restored on exit, including when the body of
        the ``with`` block raises.
        """
        orig_todos = cls.todos
        orig_common_kwargs = cls.common_kwargs
        cls.todos = todos
        cls.common_kwargs = common_kwargs or {}
        try:
            yield
        finally:
            # Without try/finally an exception thrown into the generator at
            # `yield` would skip the restore and leave stale state behind.
            cls.todos = orig_todos
            cls.common_kwargs = orig_common_kwargs

    @classmethod
    def read_assigned_work(cls, todo_index: int) -> "Tuple[Callable, ArgsT, KwargsT]":
        """Return ``(target, args, kwargs)`` for the todo at ``todo_index``.

        Missing or falsy "args" / "kwargs" entries become ``()`` / ``{}``.
        ``common_kwargs`` is merged over the todo's own kwargs, so a common
        value wins on a key clash.
        """
        todo = cls.todos[todo_index]
        target = todo["target"]
        args = todo.get("args") or ()
        kwargs = todo.get("kwargs") or {}
        if cls.common_kwargs:
            kwargs = {**kwargs, **cls.common_kwargs}
        return target, args, kwargs
@dc.dataclass(frozen=True)
class _WorkerHandle:
    """Immutable record bundling one worker's process with its pipe ends."""

    __slots__ = ("index", "proc", "parent_conn", "child_conn")

    # Position of the worker in the list of started workers.
    index: int
    # The worker process itself.
    proc: mp.Process
    # First connection returned by `mp.Pipe(duplex=False)`; the parent reads results from it.
    parent_conn: Connection
    # Second connection returned by `mp.Pipe(duplex=False)`; the worker sends results on it.
    child_conn: Connection

    def join_n_close(self) -> None:
        """Wait for the process to exit, then release it and close both pipe ends."""
        worker = self.proc
        worker.join()  # no timeout: blocks until the worker has exited
        worker.close()
        for conn in (self.parent_conn, self.child_conn):
            conn.close()
# noinspection PyProtectedMember
class _WorkerComms:
    """Length-prefixed message framing on top of one pipe connection.

    A message is an 8-byte header (the payload length, encoded by
    ``_encode_longint``) followed by the payload bytes. Objects are pickled
    with protocol 5 via ``pickle5``. Reading and writing go through the
    connection's private ``_send`` / ``_recv`` / ``_handle`` members.
    """

    def __init__(self, conn: Connection):
        self.conn = conn

    def send_pickled(self, obj: Any) -> None:
        """Pickle ``obj`` and send it as one framed message."""
        payload = self.serialize(obj)
        self.send_bytes(payload=payload)

    def serialize(self, obj: Any) -> bytes:
        """Return ``obj`` pickled with protocol 5."""
        return pickle5.dumps(obj, protocol=5)

    def deserialize(self, buf: bytes) -> Any:
        """Unpickle ``buf``, reading it through a memoryview."""
        with memoryview(buf) as view:
            return pickle5.loads(view)

    def receive_pickled(self) -> Any:
        """Receive one framed message and return the unpickled object."""
        return self.deserialize(self.receive_bytes())

    def send_bytes(self, payload: bytes) -> None:
        """Send the length header followed by ``payload``."""
        header = _encode_longint(i=len(payload))
        self.send(header)
        self.send(payload)

    def receive_bytes(self) -> bytes:
        """Read the 8-byte length header, then that many payload bytes."""
        header = self.recv(8)
        payload_length = _decode_longint(i=header)
        return self.recv2(payload_length)

    def recv(self, nbytes: int) -> bytes:
        """Read ``nbytes`` through the connection's own ``_recv`` helper."""
        # noinspection PyUnresolvedReferences
        stream = self.conn._recv(nbytes)
        return stream.getvalue()

    def recv2(self, nbytes: int) -> bytes:
        """Read ``nbytes`` straight into a preallocated buffer with ``os.readv``.

        Returns:
            The filled ``bytearray``.

        Raises:
            EOFError: End of file before any byte of the message arrived.
            OSError: End of file after part of the message arrived.
        """
        buf = bytearray(nbytes)
        with memoryview(buf) as whole:
            # noinspection PyUnresolvedReferences
            fd = self.conn._handle
            filled = 0
            while filled < nbytes:
                # Read into the not-yet-filled tail of the buffer.
                got = os.readv(fd, (whole[filled:],))
                if got == 0:
                    if filled == 0:
                        raise EOFError
                    raise OSError("got end of file during message")
                filled += got
            return buf

    def send(self, bs: bytes) -> None:
        """Write ``bs`` through the connection's own ``_send`` helper."""
        with memoryview(bs) as view:
            # noinspection PyUnresolvedReferences
            self.conn._send(view)
def _encode_longint(i: int) -> bytes:
    """Pack ``i`` as an 8-byte little-endian unsigned integer (struct format ``<Q``)."""
    packer = struct.Struct("<Q")
    return packer.pack(i)
def _decode_longint(i: bytes) -> int:
    """Unpack an 8-byte little-endian unsigned integer (struct format ``<Q``)."""
    (value,) = struct.unpack("<Q", i)
    return value
def _procprint(*a, **kw) -> None:
    """Call ``print`` with the given arguments while holding ``_g_print_lock``."""
    with _g_print_lock:
        print(*a, **kw)
def _maybe_procprint(*a, **kw) -> None:
    """Print a blank line followed by the message when verbosity is at least 2.

    Both lines are written under a single acquisition of ``_g_print_lock`` so
    that another process using the same lock cannot print between the blank
    separator and the message.
    """
    if _g_verbose < 2:
        return
    # Print directly instead of calling `_procprint` twice: each call would
    # take and release the lock separately, and taking it again while held
    # would block.
    with _g_print_lock:
        print()
        print(*a, **kw)
def _maybe_print(*a, **kw) -> None:
    """Call ``print`` with the given arguments when verbosity is at least 1."""
    if _g_verbose < 1:
        return
    print(*a, **kw)
def _get_todo_index(m_lock: mp.Lock, m_todo_index: mp.Value) -> Optional[int]:
    """Claim the next unprocessed todo index, or return None when none remain.

    Args:
        m_lock: Lock guarding the shared counter.
        m_todo_index: Shared counter holding the next unclaimed index.

    Returns:
        The claimed index (the counter is then advanced by one), or None once
        the counter has reached ``len(_WorkerGlobals.todos)``.
    """
    with m_lock:
        claimed = m_todo_index.value
        if claimed >= len(_WorkerGlobals.todos):
            return None
        m_todo_index.value = claimed + 1
        return claimed
def _worker_main(
    worker_index: int, m_lock: mp.Lock, m_todo_index: mp.Value, child_conn: Connection
):
    """Worker entry point: claim todos until none remain and send each result.

    Args:
        worker_index: Index of this worker; also stored in ``_g_worker_index``.
        m_lock: Lock guarding ``m_todo_index``.
        m_todo_index: Shared counter of the next unclaimed todo.
        child_conn: Connection on which ``(todo_index, result)`` pairs are sent.

    Raises:
        Exception: Anything raised by a target (or while sending) is printed
            with its traceback and then re-raised.
    """
    global _g_worker_index

    # Warning filters are deliberately left alone here: constructing
    # `warnings.catch_warnings()` without entering it as a context manager
    # has no effect, so no such call is made.

    # For debugging inside your target function.
    # Example: `print(f'Worker {_g_worker_index} beginning work!')`
    _g_worker_index = worker_index

    prefix = f"In worker_index={worker_index} --"
    worker_comms = _WorkerComms(child_conn)
    try:
        while True:
            todo_index = _get_todo_index(m_lock, m_todo_index)
            if todo_index is None:
                # Every todo has been claimed by some worker.
                break

            _maybe_procprint(f"{prefix} Starting todo_index={todo_index}...")
            target, args, kwargs = _WorkerGlobals.read_assigned_work(todo_index=todo_index)
            result = target(*args, **kwargs)
            _maybe_procprint(
                f"{prefix} Finished todo_index={todo_index}. Sending pickled result to parent..."
            )
            # The index travels with the result so the parent can store it
            # at the right position.
            worker_comms.send_pickled(obj=(todo_index, result))
            _maybe_procprint(
                f"{prefix} Sent pickled result for todo_index={todo_index}. Looping..."
            )
    except Exception as e:
        _procprint(
            f"Exception when executing function in subprocess. Will print traceback. Error: {repr(e)}"
        )
        traceback.print_exc()
        raise
    finally:
        # Runs on both the normal and the error path.
        _maybe_procprint(f"Queue is empty for worker_index={worker_index}. Cleaning up.")
        child_conn.close()
@timeit()
def _gather_results(handles: List[_WorkerHandle], desc: Optional[str], n_todos: int) -> List[Any]:
    """Collect all results, then join and close every worker.

    Args:
        handles: Handles of the started workers.
        desc: Progress-bar label forwarded to ``__gather_results``.
        n_todos: Number of results to wait for.

    Returns:
        The result list produced by ``__gather_results``.

    Raises:
        BaseException: Whatever ``__gather_results`` raised, re-raised after
            the live workers have been terminated.
    """

    def terminate():
        # Terminate only workers that are still running.
        for h in handles:
            if h.proc.is_alive():
                with contextlib.suppress(ValueError):
                    h.proc.terminate()

    def cleanup():
        for h in handles:
            h.join_n_close()

    try:
        return __gather_results(handles=handles, desc=desc, n_todos=n_todos)
    except KeyboardInterrupt:
        print(
            f"In {_gather_results.__qualname__} -- Received KeyboardInterrupt. Terminating workers."
        )
        terminate()
        raise
    except BaseException:
        # After any other failure nobody reads the pipes any more, so a
        # worker blocked on sending a result would never exit and the join
        # in `cleanup` would wait forever. Terminate before joining.
        terminate()
        raise
    finally:
        cleanup()
def __gather_results(handles: List[_WorkerHandle], desc: Optional[str], n_todos: int) -> List[Any]:
    """Receive ``n_todos`` results from the workers' pipes.

    Args:
        handles: Handles of the started workers.
        desc: Progress-bar label; no bar is created when falsy.
        n_todos: Number of results expected.

    Returns:
        A list in which position ``i`` holds the result of todo ``i``.

    Raises:
        RuntimeError: If every worker connection has reached end of file
            before all results arrived (for example because a worker died).
    """
    prefix = f"In {__gather_results.__qualname__} --"

    # The order of `results` will match the order of `handles`' assigned todos
    results = [None] * n_todos
    tq = tqdm(total=n_todos, desc=desc) if desc else None

    # Connections that may still deliver results.
    conn2handle: Dict[Connection, _WorkerHandle] = {h.parent_conn: h for h in handles}

    n_finished = 0
    try:
        # Loop until all workers' outputs/results have been received
        while n_finished < n_todos:
            if not conn2handle:
                # Nothing left to wait on: without this check the loop would
                # spin forever on connections that only ever report EOF.
                raise RuntimeError(
                    f"{prefix} All workers exited but only {n_finished} of {n_todos} "
                    "results were received"
                )

            # `parent_conn` for every worker with a result ready (or at EOF)
            ready_conns = mp.connection.wait(list(conn2handle))

            for conn in ready_conns:
                handle = conn2handle[conn]
                try:
                    if not conn.poll(0):
                        continue
                    todo_index, payload = _WorkerComms(conn=conn).receive_pickled()
                except (EOFError, OSError):
                    # The sending side is closed or the message was cut off:
                    # this connection cannot deliver anything more.
                    del conn2handle[conn]
                    continue

                results[todo_index] = payload
                _maybe_procprint(
                    f"{prefix} Gathered todo_index={todo_index} from worker_index={handle.index}"
                )
                if tq is not None:
                    tq.update(1)
                    _maybe_procprint("")
                n_finished += 1
    finally:
        if tq is not None:
            # Clean up the progress indicator, also when gathering failed
            tq.close()

    if desc:
        print()  # Take this out at your own risk! >:)
    return results
def _create_worker(worker_index: int, m_lock, m_todo_index) -> _WorkerHandle:
    """Create (but do not start) one worker process together with its result pipe.

    Args:
        worker_index: Index given to the worker and used in its process name.
        m_lock: Lock passed through to ``_worker_main``.
        m_todo_index: Shared todo counter passed through to ``_worker_main``.

    Returns:
        A handle holding the unstarted process and both pipe connections.
    """
    # One-way pipe: the first connection receives, the second sends.
    parent_conn, child_conn = mp.Pipe(duplex=False)

    # Workers read their todos from `_WorkerGlobals`, which a child only sees
    # if it is forked from this process. Ask for the fork start method
    # explicitly, because it is not the default on every supported platform
    # (macOS defaults to spawn since Python 3.8).
    fork_ctx = mp.get_context("fork")
    proc = fork_ctx.Process(
        target=_worker_main,
        kwargs=dict(
            worker_index=worker_index,
            m_lock=m_lock,
            m_todo_index=m_todo_index,
            child_conn=child_conn,
        ),
        name=f"processit-{worker_index}",
        daemon=False,
    )
    return _WorkerHandle(
        index=worker_index, proc=proc, parent_conn=parent_conn, child_conn=child_conn
    )
@timeit()
def _start_workers(n: int, m_lock, m_todo_index) -> List[_WorkerHandle]:
    """Create and start ``n`` workers.

    Args:
        n: Number of workers to start.
        m_lock: Lock shared by all workers.
        m_todo_index: Todo counter shared by all workers.

    Returns:
        The handles of the started workers, ordered by worker index. The
        parent's copy of each ``child_conn`` is already closed.
    """
    _maybe_print(f"In {_start_workers.__qualname__} -- Starting {n} workers")
    handles = []
    for i in range(n):
        # Create each pipe right before its own worker starts and close the
        # parent's sending end right after. When workers are forked, a child
        # inherits whatever pipe ends are open at that moment; creating all
        # pipes up front would leave later workers' sending ends open inside
        # earlier workers, so the parent would not see end-of-file on a pipe
        # until those other workers had exited too.
        handle = _create_worker(i, m_lock, m_todo_index)
        handle.proc.start()
        handle.child_conn.close()
        handles.append(handle)
    return handles
def _execute_todo(todo: dict, common_kwargs) -> Any:
    """Call one todo's target in the current process and return its result.

    Args:
        todo: Dict with a required "target" callable and optional "args" /
            "kwargs" entries; missing or falsy entries mean no arguments.
        common_kwargs: Keyword arguments merged over the todo's own kwargs
            (a common value wins on a key clash). ``None`` or an empty
            mapping means no shared arguments.

    Returns:
        Whatever the target returns.
    """
    target = todo["target"]
    args = todo.get("args") or ()
    kwargs = todo.get("kwargs") or {}
    # Tolerate None the same way `_WorkerGlobals.read_assigned_work` does.
    kwargs = {**kwargs, **(common_kwargs or {})}
    return target(*args, **kwargs)
def _execute_todos_serially(*, chunk_of_todos: List[dict], **common_kwargs) -> List[Any]:
    """Run a chunk of todos one after another in the current process.

    Args:
        chunk_of_todos: Todos to execute, in order.
        **common_kwargs: Keyword arguments handed to ``_execute_todo`` for
            every todo of the chunk.

    Returns:
        The results in the order of ``chunk_of_todos``.
    """
    results = [_execute_todo(item, common_kwargs) for item in chunk_of_todos]
    return results
def _chunkify_todos(todos: List[dict], nchunks: int) -> List[dict]:
    """Group ``todos`` into chunks and wrap each chunk as a single todo.

    Args:
        todos: Todos to group.
        nchunks: Number of chunks used to derive the chunk size.

    Returns:
        One todo per chunk whose target is ``_execute_todos_serially`` and
        whose kwargs carry the chunk as ``chunk_of_todos``.
    """
    size: int = get_chunksz(len(todos), nchunks=nchunks)
    grouped: Iterable[List[dict]] = chunks(todos, chunksz=size)
    return [
        dict(target=_execute_todos_serially, kwargs=dict(chunk_of_todos=group))
        for group in grouped
    ]
def chunks(l: Collection, chunksz: int = 0, nchunks: int = 0) -> Iterable:
    """Split ``l`` either into chunks of a given size or into a given number of chunks.

    Exactly one of ``chunksz`` and ``nchunks`` must be truthy.

    Args:
        l: Sliceable collection to split.
        chunksz: Size of each chunk (handled by ``_n_sized_chunks``).
        nchunks: Number of chunks (handled by ``_exactly_n_chunks``).

    Returns:
        An iterable of slices of ``l``.

    Raises:
        AssertionError: If both or neither of the two options are truthy.
    """
    assert bool(chunksz) ^ bool(nchunks), (chunksz, nchunks)
    if not chunksz:
        return _exactly_n_chunks(l=l, n=nchunks)
    return _n_sized_chunks(l=l, n=chunksz)
def _n_sized_chunks(l: Collection, n: int) -> Iterable:
    """Yield consecutive slices of ``l`` holding ``n`` items each.

    The last slice is shorter when ``len(l)`` is not a multiple of ``n``.
    """
    start = 0
    for start in range(0, len(l), n):
        stop = start + n
        yield l[start:stop]
def _exactly_n_chunks(l: Collection, n: int) -> Iterable:
    """Yield consecutive slices of ``l``, splitting it as evenly as possible into ``n`` parts.

    The first ``len(l) % n`` slices hold one item more than the rest. Empty
    slices are never produced, so fewer than ``n`` slices are yielded when
    ``len(l) < n`` (and none for an empty ``l``).

    Args:
        l: Sliceable collection to split.
        n: Number of parts to aim for.

    Raises:
        ZeroDivisionError: If ``n`` is 0 (raised on first iteration).
    """
    total = len(l)
    # Integer divmod gives the base size and how many slices need one extra
    # item, without re-summing a size list on every step.
    base, extra = divmod(total, n)
    start = 0
    for i in range(n):
        if start >= total:
            # Only empty slices would follow.
            return
        stop = start + base + (1 if i < extra else 0)
        yield l[start:stop]
        start = stop
def get_chunksz(n_items_to_process: int, nchunks: int = cpu_count()) -> int:
    """Return the chunk size needed to cover all items with ``nchunks`` chunks.

    This is ``n_items_to_process / nchunks`` rounded up to a whole number.
    """
    per_chunk = n_items_to_process / nchunks
    return int(math.ceil(per_chunk))
######### | |
# TESTS # | |
######### | |
class _Tests:
    """Ad-hoc checks that are run when this file is executed as a script."""

    @staticmethod
    def test_encode_decode_payload_length_header():
        """Check that a payload length survives the 8-byte header round trip."""
        obj = [1, "a", None]
        payload = pickle5.dumps(obj, protocol=5)
        header = _encode_longint(len(payload))
        assert len(header) == 8
        # 23 is the expected size of the protocol-5 pickle of `obj`.
        assert len(payload) == 23 == _decode_longint(header)

    @staticmethod
    def test_processit_chunked():
        """Compare results and wall-clock time of serial, parallel and chunked runs.

        Computes the skew of every row of a DataFrame four ways (pandas'
        own ``skew(axis=1)``, a serial loop, ``processit`` and
        ``processit_chunky``) and asserts that the results agree and that
        the parallel variants are faster by the asserted margins.
        """
        import numpy as np
        import pandas as pd

        njobs = 6
        # 10_000 x 10_000 frame, i.e. 1e8 cells.
        sz = int(1e4)
        shape = (sz, sz)
        df = pd.DataFrame(np.arange(shape[0] * shape[1]).reshape(shape))

        def skew_of_ith_row(i_, df_):
            return df_.loc[i_].skew()

        # Baseline 1: pandas computing all row skews in one call.
        normal_time = time.time()
        normal = df.skew(axis=1)
        normal_time = time.time() - normal_time

        # Baseline 2: one Python-level call per row, in this process.
        serial_time = time.time()
        expected = [skew_of_ith_row(i, df) for i in df.index]
        serial_time = time.time() - serial_time

        # One todo per row.
        parallel_time = time.time()
        actual = processit(
            [dict(target=skew_of_ith_row, args=(i, df)) for i in df.index], max_nprocs=njobs,
        )
        parallel_time = time.time() - parallel_time

        # Same todos, grouped into chunks.
        parallel_chunky_time = time.time()
        actual_chunky = processit_chunky(
            [dict(target=skew_of_ith_row, args=(i, df)) for i in df.index], max_nprocs=njobs,
        )
        parallel_chunky_time = time.time() - parallel_chunky_time

        times = dict(
            normal_time=normal_time,
            serial_time=serial_time,
            parallel_time=parallel_time,
            parallel_chunky_time=parallel_chunky_time,
        )
        print(times)
        assert expected == actual == actual_chunky == list(normal.to_numpy())
        # NOTE(review): the assertions below depend on machine load and core
        # count, so they may fail intermittently.
        assert parallel_chunky_time < parallel_time, times
        assert parallel_time < normal_time / 2, times
        assert parallel_time < serial_time / 2, times
# Run the checks only when this file is executed as a script.
if __name__ == "__main__":
    _Tests.test_encode_decode_payload_length_header()
    _Tests.test_processit_chunked()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment