Skip to content

Instantly share code, notes, and snippets.

@jprochazk
Last active October 9, 2023 12:02
Show Gist options
  • Save jprochazk/fcc3cbebd92e0b526e744d94a4375401 to your computer and use it in GitHub Desktop.
python parallel DAG
from __future__ import annotations
import time
from concurrent.futures import ThreadPoolExecutor
from multiprocessing import Event, cpu_count
from multiprocessing.synchronize import Event as EventClass
from queue import Empty, Queue
from typing import Callable, Generic, Hashable, TypeVar
T = TypeVar("T", bound=Hashable)


class Node(Generic[T]):
    """Bookkeeping wrapper around one graph value.

    Tracks how many unfinished dependencies remain (``counter``) and which
    nodes are waiting on this one (``dependents``).
    """

    def __init__(self, value: T):
        # The user-supplied (hashable) payload this node represents.
        self.value = value
        # Remaining unfinished dependencies; the node is runnable at zero.
        self.counter: int = 0
        # Nodes that become eligible once this node finishes.
        self.dependents: list[Node[T]] = []


class ParallelDAG(Generic[T]):
    """Walk a dependency graph in parallel with token-bucket rate limiting.

    ``dependency_graph`` maps each node to the list of nodes it depends on.
    ``walk(f, ...)`` invokes ``f`` on every node — never before all of that
    node's dependencies have been processed — using a thread pool, and
    submits at most ``max_tokens`` tasks per ``refill_interval_s`` seconds.
    """

    def __init__(self, dependency_graph: dict[T, list[T]]):
        # All known nodes, keyed by value.  Dependencies that appear only
        # on the right-hand side of the mapping are created implicitly.
        self._nodes: dict[T, Node[T]] = {}
        # Nodes whose dependencies have all finished (ready to run).
        self._queue: list[T] = []
        # Count of nodes whose processing has completed.
        self._num_finished: int = 0
        for node, deps in dependency_graph.items():
            new_node = self._get_or_insert(node)
            new_node.counter += len(deps)
            for dep in deps:
                self._get_or_insert(dep).dependents.append(new_node)
        # Seed the ready queue with every node that has no dependencies.
        self._queue.extend(
            node.value for node in self._nodes.values() if node.counter == 0
        )

    def _get_or_insert(self, node: T) -> Node[T]:
        """Return the ``Node`` wrapper for ``node``, creating it on first sight."""
        if node not in self._nodes:
            self._nodes[node] = Node(node)
        return self._nodes[node]

    def _finish(self, node: T) -> None:
        """Mark ``node`` done and enqueue any dependents that became ready."""
        for dependent in self._nodes[node].dependents:
            dependent.counter -= 1
            if dependent.counter == 0:
                self._queue.append(dependent.value)
        self._num_finished += 1

    def _is_done(self) -> bool:
        """Return True once every node has been processed."""
        return self._num_finished == len(self._nodes)

    def walk(
        self,
        f: Callable[[T], None],
        max_tokens: int,
        refill_interval_s: float,
    ) -> None:
        """Call ``f`` on every node in dependency order, in parallel.

        ``f`` runs on worker threads, so it must be thread-safe.  At most
        ``max_tokens`` nodes are dispatched per ``refill_interval_s``-second
        window.
        """
        num_cpus = cpu_count()
        # BUG FIX: the original used `num_cpus - 1 if num_cpus > 0 else 1`,
        # which yields 0 workers on a single-CPU machine and makes
        # ThreadPoolExecutor raise ValueError.  Always keep >= 1 worker.
        num_workers = max(num_cpus - 1, 1)
        with ThreadPoolExecutor(max_workers=num_workers) as p:
            task_queue: Queue[T] = Queue()
            done_queue: Queue[T] = Queue()
            shutdown: EventClass = Event()

            def worker(n: int) -> None:
                # Pull tasks until told to shut down.  A short blocking get
                # replaces the original get_nowait()/sleep(0) busy-spin.
                while not shutdown.is_set():
                    try:
                        node = task_queue.get(timeout=0.01)
                    except Empty:
                        continue
                    f(node)
                    done_queue.put(node)

            for n in range(num_workers):
                p.submit(worker, n)

            tokens = max_tokens
            last_refill = time.time()
            while not self._is_done():
                # BUG FIX: refill the token bucket on every pass.  The
                # original refilled only inside the submission loop, so with
                # zero tokens, a non-empty ready queue, and nothing in
                # flight, the blocking done_queue.get() deadlocked forever.
                now = time.time()
                if now - last_refill > refill_interval_s:
                    tokens = max_tokens
                    last_refill = now
                # Dispatch as many ready nodes as the token budget allows.
                while self._queue and tokens > 0:
                    tokens -= 1
                    task_queue.put(self._queue.pop())
                # Bounded wait for a completion, so refills keep happening
                # even while nothing is currently in flight.
                try:
                    self._finish(
                        done_queue.get(timeout=max(refill_interval_s, 0.01))
                    )
                except Empty:
                    continue
            shutdown.set()
def main() -> None:
    """Demo: walk a small diamond-shaped graph with rate limiting."""

    def process(node: str) -> None:
        # Simulate one second of work per node.
        time.sleep(1)
        print(f"processed {node}")

    # A, B, C are independent; D waits for all three.
    graph = {
        "A": [],
        "B": [],
        "C": [],
        "D": ["A", "B", "C"],
    }
    dag = ParallelDAG(graph)
    dag.walk(process, max_tokens=2, refill_interval_s=1)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment