Last active
October 9, 2023 12:02
-
-
Save jprochazk/fcc3cbebd92e0b526e744d94a4375401 to your computer and use it in GitHub Desktop.
Python parallel DAG
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations
import time
from concurrent.futures import ThreadPoolExecutor
from multiprocessing import Event, cpu_count
from multiprocessing.synchronize import Event as EventClass
from queue import Empty, Queue
from typing import Callable, Generic, Hashable, TypeVar
T = TypeVar("T", bound=Hashable)


class Node(Generic[T]):
    """Internal bookkeeping wrapper for one DAG node.

    Tracks how many unfinished dependencies remain (`counter`) and which
    nodes are waiting on this one (`dependents`).
    """

    def __init__(self, value: T):
        self.value = value
        # Number of dependencies that have not finished yet; the node is
        # ready to run when this reaches zero.
        self.counter: int = 0
        # Nodes that depend on this one and must be notified on finish.
        self.dependents: list[Node[T]] = []


class ParallelDAG(Generic[T]):
    """Walk a dependency graph in parallel.

    The graph is given as ``{node: [dependencies, ...]}`` with hashable
    node values. A node is processed only after all of its dependencies
    have been processed; independent nodes run concurrently on a thread
    pool, rate-limited by a token bucket.
    """

    def __init__(self, dependency_graph: dict[T, list[T]]):
        self._nodes: dict[T, Node[T]] = {}
        # Ready-queue: nodes whose dependencies are all satisfied.
        self._queue: list[T] = []
        self._num_finished: int = 0
        for node, deps in dependency_graph.items():
            new_node = self._get_or_insert(node)
            new_node.counter += len(deps)
            for dep in deps:
                self._get_or_insert(dep).dependents.append(new_node)
        # Seed the ready-queue with nodes that have no dependencies.
        self._queue.extend(node.value for node in self._nodes.values() if node.counter == 0)

    def _get_or_insert(self, node: T) -> Node[T]:
        """Return the wrapper for `node`, creating it on first sight."""
        if node not in self._nodes:
            self._nodes[node] = Node(node)
        return self._nodes[node]

    def _finish(self, node: T) -> None:
        """Mark `node` done and enqueue any dependents that became ready."""
        for dependent in self._nodes[node].dependents:
            dependent.counter -= 1
            if dependent.counter == 0:
                self._queue.append(dependent.value)
        self._num_finished += 1

    def _is_done(self) -> bool:
        """True once every node in the graph has finished."""
        return self._num_finished == len(self._nodes)

    def walk(self, f: Callable[[T], None], max_tokens: int, refill_interval_s: float) -> None:
        """Call `f` on every node, respecting dependency order.

        At most `max_tokens` nodes are dispatched per `refill_interval_s`
        seconds (simple token bucket). Blocks until the whole graph has
        been processed.
        """
        num_cpus = cpu_count()
        # BUGFIX: the original computed `num_cpus - 1 if num_cpus > 0 else 1`,
        # which yields max_workers=0 on a single-CPU machine and makes
        # ThreadPoolExecutor raise ValueError. Always keep at least one worker.
        num_workers = max(1, num_cpus - 1)
        with ThreadPoolExecutor(max_workers=num_workers) as p:
            task_queue: Queue[T] = Queue()
            done_queue: Queue[T] = Queue()
            shutdown: EventClass = Event()

            def worker(n: int) -> None:
                while not shutdown.is_set():
                    try:
                        # Short blocking get instead of get_nowait()+sleep(0):
                        # same behavior, no busy-spin; still notices shutdown
                        # promptly via the timeout.
                        node = task_queue.get(timeout=0.01)
                    except Empty:
                        continue
                    try:
                        f(node)
                    finally:
                        # Always report completion, even if `f` raised —
                        # otherwise the coordinator would wait forever for a
                        # done report that never comes.
                        done_queue.put(node)

            for n in range(num_workers):
                p.submit(worker, n)

            tokens = max_tokens
            last_refill = time.time()
            while not self._is_done():
                while self._queue:
                    now = time.time()
                    if now - last_refill > refill_interval_s:
                        tokens = max_tokens
                        last_refill = now
                    if tokens == 0:
                        break  # bucket empty; wait for refill or completions
                    tokens -= 1
                    task_queue.put(self._queue.pop())
                # BUGFIX: the original blocked unconditionally on
                # done_queue.get(); with ready nodes, zero tokens and no tasks
                # in flight, nothing ever arrives and the walk deadlocks.
                # Bound the wait so we wake up in time for the next refill.
                wait_s = max(0.0, refill_interval_s - (time.time() - last_refill)) + 0.01
                try:
                    self._finish(done_queue.get(timeout=wait_s))
                except Empty:
                    continue  # timed out: loop back and refill the bucket
            shutdown.set()
def main() -> None:
    """Demo: three independent nodes plus one that depends on all of them."""

    def process(node: str) -> None:
        # Simulate a slow unit of work, then report it.
        time.sleep(1)
        print(f"processed {node}")

    graph = {
        "A": [],
        "B": [],
        "C": [],
        "D": ["A", "B", "C"],
    }
    dag = ParallelDAG(graph)
    # Rate-limit dispatch to 2 nodes per second.
    dag.walk(process, max_tokens=2, refill_interval_s=1)


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment