"""
System/Runtime Requirements:
>=Python3.7
Linux / Mac
>=2 CPU Cores
Must pip install to run `processit`:
pickle5, tqdm
Must pip install to run the tests and some functions:
numpy, pandas, seaborn
"""
import contextlib
import dataclasses as dc
import math
import multiprocessing as mp
import os
import platform
import struct
import time
import traceback
import warnings
from functools import wraps
from itertools import chain
from multiprocessing import cpu_count
from multiprocessing.connection import Connection
from typing import Optional, List, Any, Tuple, Callable, Dict, Collection, Iterable, Union
import pickle5
from tqdm import tqdm
PdDataFrameT = "pandas.DataFrame"
PdIndexT = "pandas.Index"
ArgsT = Tuple[Any, ...]
KwargsT = Dict[str, Any]
###########
# GLOBALS #
###########
_g_worker_index = None
_g_verbose = 0
_g_print_lock = mp.Lock()
##########
# PUBLIC #
##########
def timeit():
    def outer(func):
        @wraps(func)
        def inner(*args, **kwargs):
            start_time = time.time()
            res = func(*args, **kwargs)
            interval = time.time() - start_time
            _maybe_print("Time for '%s': %0.3f seconds" % (func.__qualname__, interval))
            return res

        return inner

    return outer


def set_verbose(v: Union[bool, int]) -> None:
    global _g_verbose
    _g_verbose = int(v)

@timeit()
def processit(
    todos: List[dict],
    desc: Optional[str] = None,
    max_nprocs: int = cpu_count(),
    common_kwargs=None,
) -> List[Any]:
    n_todos = len(todos)
    max_nprocs = max_nprocs or cpu_count()
    assert len(todos) > 0, n_todos
    assert isinstance(next(iter(todos)), dict), type(next(iter(todos)))
    if platform.system() not in ("Linux", "Darwin"):
        # No fork-based multiprocessing available: run the todos serially in this process.
        return [
            t["target"](
                *(t.get("args") or tuple()),
                **(t.get("kwargs") or dict()),
                **(common_kwargs or dict()),
            )
            for t in todos
        ]
    import gc
    gc.collect()
    with _WorkerGlobals.context(todos=todos, common_kwargs=common_kwargs):
        n_workers = min(n_todos, max_nprocs)
        m_lock = mp.Lock()
        m_todo_index = mp.Value("i", 0)
        handles: List[_WorkerHandle] = _start_workers(
            n=n_workers, m_lock=m_lock, m_todo_index=m_todo_index
        )
        return _gather_results(handles=handles, desc=desc, n_todos=n_todos)

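# Illustrative usage sketch (hypothetical `work` helper; not part of the original gist):
# each todo is a dict with a "target" plus optional "args"/"kwargs", results come back
# in todo order, and `common_kwargs` is merged into every todo's kwargs.
def _example_processit_usage() -> List[int]:
    def work(x: int, offset: int = 0) -> int:
        return x * x + offset

    todos = [dict(target=work, args=(x,)) for x in (1, 2, 3)]
    return processit(todos, desc="squaring", common_kwargs=dict(offset=10))  # -> [11, 14, 19]
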
@timeit()
def processit_chunky(
    todos: List[dict],
    desc: Optional[str] = None,
    max_nprocs: int = cpu_count(),
    common_kwargs=None,
) -> List[Any]:
    n_todos = len(todos)
    assert len(todos) > 0, n_todos
    assert isinstance(next(iter(todos)), dict), type(next(iter(todos)))
    todos = _chunkify_todos(todos, nchunks=max_nprocs)
    return list(
        chain.from_iterable(
            processit(todos=todos, desc=desc, max_nprocs=max_nprocs, common_kwargs=common_kwargs)
        )
    )

def processit_map(
    target: Callable,
    map_onto: List,
    desc: Optional[str] = None,
    max_nprocs=cpu_count(),
    common_kwargs=None,
    chunky=False,
) -> List[Any]:
    todos = [dict(target=target, args=(args,)) for args in map_onto]
    kw = dict(todos=todos, desc=desc, max_nprocs=max_nprocs, common_kwargs=common_kwargs)
    if chunky:
        return processit_chunky(**kw)
    else:
        return processit(**kw)

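# Illustrative sketch (hypothetical `double` helper): `processit_map` wraps each
# element of `map_onto` as the single positional argument passed to `target`.
def _example_processit_map_usage() -> List[int]:
    def double(x: int) -> int:
        return 2 * x

    return processit_map(double, [1, 2, 3], desc="doubling")  # -> [2, 4, 6]
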
def processit_map_identity(
    targets: List[Callable], desc: Optional[str] = None, common_kwargs=None
) -> List[Any]:
    todos = [dict(target=target) for target in targets]
    return processit(todos=todos, desc=desc, common_kwargs=common_kwargs)


def processit_dict(
    todos_dict: Dict[str, dict],
    desc: Optional[str] = None,
    max_nprocs: int = cpu_count(),
    common_kwargs=None,
) -> Dict[str, Any]:
    keys, todos = zip(*todos_dict.items())
    return dict(
        zip(keys, processit(todos, desc=desc, max_nprocs=max_nprocs, common_kwargs=common_kwargs))
    )

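# Illustrative sketch: `processit_dict` keys the gathered results by the todo-dict
# keys, which is convenient for heterogeneous jobs.
def _example_processit_dict_usage() -> Dict[str, Any]:
    todos_dict = {
        "sum": dict(target=sum, args=([1, 2, 3],)),
        "max": dict(target=max, args=([1, 2, 3],)),
    }
    return processit_dict(todos_dict, desc="aggregating")  # -> {"sum": 6, "max": 3}
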
###########
# PRIVATE #
###########
class _WorkerGlobals:
    # Shared via fork: the parent sets these class attributes before starting the
    # workers, so each forked child can read the todos without pickling them.
    todos = None
    common_kwargs = None

    @classmethod
    @contextlib.contextmanager
    def context(cls, todos, common_kwargs):
        orig_todos = cls.todos
        orig_common_kwargs = cls.common_kwargs
        cls.todos = todos
        cls.common_kwargs = common_kwargs or {}
        yield
        cls.todos = orig_todos
        cls.common_kwargs = orig_common_kwargs

    @classmethod
    def read_assigned_work(cls, todo_index: int) -> Tuple[Callable, ArgsT, KwargsT]:
        todo = cls.todos[todo_index]
        target = todo["target"]
        args = todo.get("args") or ()
        kwargs = todo.get("kwargs") or {}
        if cls.common_kwargs:
            kwargs = {**kwargs, **cls.common_kwargs}
        return target, args, kwargs

@dc.dataclass(frozen=True)
class _WorkerHandle:
    __slots__ = ("index", "proc", "parent_conn", "child_conn")
    index: int
    proc: mp.Process
    parent_conn: Connection
    child_conn: Connection

    def join_n_close(self) -> None:
        self.proc.join(timeout=None)
        self.proc.close()
        self.parent_conn.close()
        self.child_conn.close()

# noinspection PyProtectedMember
class _WorkerComms:
    # Length-prefixed pickle protocol over a multiprocessing Pipe: each message is
    # an 8-byte little-endian length header followed by the pickle5 payload.
    def __init__(self, conn: Connection):
        self.conn = conn

    def send_pickled(self, obj: Any) -> None:
        self.send_bytes(payload=self.serialize(obj))

    def serialize(self, obj: Any) -> bytes:
        return pickle5.dumps(obj, protocol=5)

    def deserialize(self, buf: bytes) -> Any:
        with memoryview(buf) as buf_view:
            return pickle5.loads(buf_view)

    def receive_pickled(self) -> Any:
        payload_bytes = self.receive_bytes()
        return self.deserialize(payload_bytes)

    def send_bytes(self, payload: bytes) -> None:
        self.send(_encode_longint(i=len(payload)))
        self.send(payload)

    def receive_bytes(self) -> bytes:
        payload_length = _decode_longint(i=self.recv(8))
        return self.recv2(payload_length)

    def recv(self, nbytes: int) -> bytes:
        # noinspection PyUnresolvedReferences
        return self.conn._recv(nbytes).getvalue()

    def recv2(self, nbytes: int) -> bytes:
        # Read exactly `nbytes` directly into a preallocated buffer to avoid
        # intermediate copies of large payloads.
        buf = bytearray(nbytes)
        with memoryview(buf) as buf_view:
            # noinspection PyUnresolvedReferences
            handle = self.conn._handle
            remaining = nbytes
            bytes_written = 0
            while remaining > 0:
                n = os.readv(handle, (buf_view,))
                if n == 0:
                    if remaining == nbytes:
                        raise EOFError
                    else:
                        raise OSError("got end of file during message")
                buf_view = buf_view[n:]
                remaining -= n
                bytes_written += n
        return buf

    def send(self, bs: bytes) -> None:
        with memoryview(bs) as bs_view:
            # noinspection PyUnresolvedReferences
            self.conn._send(bs_view)

def _encode_longint(i: int) -> bytes:
    return struct.pack("<Q", i)


def _decode_longint(i: bytes) -> int:
    return struct.unpack("<Q", i)[0]

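# Framing sketch: every message _WorkerComms sends is an 8-byte little-endian
# unsigned length header followed by the pickle5 payload, so the reader knows
# exactly how many bytes to consume before unpickling.
def _example_framing() -> None:
    header = _encode_longint(23)
    assert header == b"\x17" + b"\x00" * 7
    assert _decode_longint(header) == 23
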
def _procprint(*a, **kw) -> None:
    global _g_print_lock
    with _g_print_lock:
        print(*a, **kw)


def _maybe_procprint(*a, **kw) -> None:
    global _g_verbose
    if _g_verbose >= 2:
        _procprint()
        _procprint(*a, **kw)

def _maybe_print(*a, **kw) -> None:
    global _g_verbose
    if _g_verbose >= 1:
        print(*a, **kw)


def _get_todo_index(m_lock: mp.Lock, m_todo_index: mp.Value) -> Optional[int]:
    with m_lock:
        if m_todo_index.value < len(_WorkerGlobals.todos):
            ret = m_todo_index.value
            m_todo_index.value += 1
            return ret

def _worker_main(
    worker_index: int, m_lock: mp.Lock, m_todo_index: mp.Value, child_conn: Connection
):
    global _g_worker_index
    global _g_verbose
    # You may want to comment this out
    warnings.catch_warnings(record=False)
    # For debugging inside your target function.
    # Example: `print(f'Worker {_g_worker_index} beginning work!')`
    _g_worker_index = worker_index
    prefix = f"In worker_index={worker_index} --"
    worker_comms = _WorkerComms(child_conn)
    try:
        while True:
            todo_index = _get_todo_index(m_lock, m_todo_index)
            if todo_index is None:
                break
            _maybe_procprint(f"{prefix} Starting todo_index={todo_index}...")
            target, args, kwargs = _WorkerGlobals.read_assigned_work(todo_index=todo_index)
            result = target(*args, **kwargs)
            _maybe_procprint(
                f"{prefix} Finished todo_index={todo_index}. Sending pickled result to parent..."
            )
            worker_comms.send_pickled(obj=(todo_index, result))
            _maybe_procprint(
                f"{prefix} Sent pickled result for todo_index={todo_index}. Looping..."
            )
    except Exception as e:
        _procprint(
            f"Exception when executing function in subprocess. Will print traceback. Error: {repr(e)}"
        )
        traceback.print_exc()
        raise
    finally:
        _maybe_procprint(f"Queue is empty for worker_index={worker_index}. Cleaning up.")
        child_conn.close()

@timeit()
def _gather_results(handles: List[_WorkerHandle], desc: Optional[str], n_todos: int) -> List[Any]:
    def terminate():
        for h in handles:
            if h.proc.is_alive():
                with contextlib.suppress(ValueError):
                    h.proc.terminate()

    def cleanup():
        for h in handles:
            h.join_n_close()

    try:
        return __gather_results(handles=handles, desc=desc, n_todos=n_todos)
    except KeyboardInterrupt:
        print(
            f"In {_gather_results.__qualname__} -- Received KeyboardInterrupt. Terminating workers."
        )
        terminate()
        raise
    finally:
        cleanup()

def __gather_results(handles: List[_WorkerHandle], desc: Optional[str], n_todos: int) -> List[Any]:
    global _g_verbose
    prefix = f"In {__gather_results.__qualname__} --"
    # The order of `results` will match the order of `handles`' assigned todos
    results = [None] * n_todos
    if desc:
        tq: tqdm = tqdm(total=n_todos, desc=desc)
    conn2handle: Dict[Connection, _WorkerHandle] = {h.parent_conn: h for h in handles}
    # Loop until all workers' outputs/results have been received
    n_finished = 0
    while n_finished < n_todos:
        # `parent_conn` for every worker with a result ready
        ready_conns = mp.connection.wait(list(conn2handle))
        # Gather results from finished workers
        for conn in ready_conns:
            handle = conn2handle[conn]
            try:
                if conn.poll(0):
                    todo_index, payload = _WorkerComms(conn=conn).receive_pickled()
                    results[todo_index] = payload
                    _maybe_procprint(
                        f"{prefix} Gathered todo_index={todo_index} from worker_index={handle.index}"
                    )
                    if desc:
                        # noinspection PyUnboundLocalVariable
                        tq.update(1)
                    _maybe_procprint("")
                    n_finished += 1
            except EOFError:
                continue
            except OSError:
                continue
    if desc:
        # Clean up the progress indicator
        tq.close()
        print()  # Take this out at your own risk! >:)
    return results

def _create_worker(worker_index: int, m_lock, m_todo_index) -> _WorkerHandle:
    parent_conn, child_conn = mp.Pipe(duplex=False)
    proc = mp.Process(
        target=_worker_main,
        kwargs=dict(
            worker_index=worker_index,
            m_lock=m_lock,
            m_todo_index=m_todo_index,
            child_conn=child_conn,
        ),
        name=f"processit-{worker_index}",
        daemon=False,
    )
    return _WorkerHandle(
        index=worker_index, proc=proc, parent_conn=parent_conn, child_conn=child_conn
    )

@timeit()
def _start_workers(n: int, m_lock, m_todo_index) -> List[_WorkerHandle]:
    _maybe_print(f"In {_start_workers.__qualname__} -- Starting {n} workers")
    handles = [_create_worker(i, m_lock, m_todo_index) for i in range(n)]
    for h in handles:
        h.proc.start()
        h.child_conn.close()
    return handles

def _execute_todo(todo: dict, common_kwargs) -> Any:
    target = todo["target"]
    args = todo.get("args") or ()
    kwargs = todo.get("kwargs") or {}
    kwargs = {**kwargs, **common_kwargs}
    return target(*args, **kwargs)


def _execute_todos_serially(*, chunk_of_todos: List[dict], **common_kwargs) -> List[Any]:
    return [_execute_todo(todo, common_kwargs) for todo in chunk_of_todos]


def _chunkify_todos(todos: List[dict], nchunks: int) -> List[dict]:
    chunksz: int = get_chunksz(len(todos), nchunks=nchunks)
    chunks_of_todos: Iterable[List[dict]] = chunks(todos, chunksz=chunksz)
    return [
        {"target": _execute_todos_serially, "kwargs": {"chunk_of_todos": chunk_of_todos}}
        for chunk_of_todos in chunks_of_todos
    ]

def chunks(l: Collection, chunksz: int = 0, nchunks: int = 0) -> Iterable:
    assert bool(chunksz) ^ bool(nchunks), (chunksz, nchunks)
    if chunksz:
        return _n_sized_chunks(l=l, n=chunksz)
    else:
        return _exactly_n_chunks(l=l, n=nchunks)


def _n_sized_chunks(l: Collection, n: int) -> Iterable:
    """Yield successive `n`-sized chunks from `l`."""
    for i in range(0, len(l), n):
        s = i
        e = i + n
        yield l[s:e]


def _exactly_n_chunks(l: Collection, n: int) -> Iterable:
    """Yield exactly `n` chunks from `l`."""
    chunksz = int(math.floor(len(l) / n))
    chunkszs = [chunksz for _ in range(n)]
    for i in range(n):
        if sum(chunkszs) < len(l):
            chunkszs[i] += 1
    s = 0
    chunkszs = iter(chunkszs)
    while s < len(l):
        e = s + next(chunkszs)
        yield l[s:e]
        s = e


def get_chunksz(n_items_to_process: int, nchunks: int = cpu_count()) -> int:
    return int(math.ceil(n_items_to_process / nchunks))

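# Illustrative sketch of the chunking helpers: `get_chunksz` rounds up so `nchunks`
# chunks always cover every item, and `chunks` splits either by a fixed chunk size
# or into an exact number of chunks.
def _example_chunking() -> None:
    items = list(range(10))
    assert get_chunksz(10, nchunks=3) == 4
    assert list(chunks(items, chunksz=4)) == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
    assert list(chunks(items, nchunks=3)) == [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]
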
#########
# TESTS #
#########
class _Tests:
    @staticmethod
    def test_encode_decode_payload_length_header():
        obj = [1, "a", None]
        payload = pickle5.dumps(obj, protocol=5)
        header = _encode_longint(len(payload))
        assert len(header) == 8
        assert len(payload) == 23 == _decode_longint(header)

    @staticmethod
    def test_processit_chunked():
        import numpy as np
        import pandas as pd

        njobs = 6
        sz = int(1e4)
        shape = (sz, sz)
        df = pd.DataFrame(np.arange(shape[0] * shape[1]).reshape(shape))

        def skew_of_ith_row(i_, df_):
            return df_.loc[i_].skew()

        normal_time = time.time()
        normal = df.skew(axis=1)
        normal_time = time.time() - normal_time

        serial_time = time.time()
        expected = [skew_of_ith_row(i, df) for i in df.index]
        serial_time = time.time() - serial_time

        parallel_time = time.time()
        actual = processit(
            [dict(target=skew_of_ith_row, args=(i, df)) for i in df.index], max_nprocs=njobs
        )
        parallel_time = time.time() - parallel_time

        parallel_chunky_time = time.time()
        actual_chunky = processit_chunky(
            [dict(target=skew_of_ith_row, args=(i, df)) for i in df.index], max_nprocs=njobs
        )
        parallel_chunky_time = time.time() - parallel_chunky_time

        times = dict(
            normal_time=normal_time,
            serial_time=serial_time,
            parallel_time=parallel_time,
            parallel_chunky_time=parallel_chunky_time,
        )
        print(times)
        assert expected == actual == actual_chunky == list(normal.to_numpy())
        assert parallel_chunky_time < parallel_time, times
        assert parallel_time < normal_time / 2, times
        assert parallel_time < serial_time / 2, times

if __name__ == "__main__":
_Tests.test_encode_decode_payload_length_header()
_Tests.test_processit_chunked()