Last active
August 15, 2022 08:04
-
-
Save acarapetis/a103661a2bf0789b2c7863c3de8aa870 to your computer and use it in GitHub Desktop.
Simple task DAG executor in Python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from concurrent.futures import ProcessPoolExecutor, wait, FIRST_COMPLETED | |
from typing import Iterable, Set | |
from time import sleep | |
class Task:
    """A named unit of work with dependencies and an optional cleanup hook.

    Every instance registers itself in the class-level ``_tasks`` mapping
    under its name so that it can later be looked up with ``Task.get``.

    Attributes:
        do_work: the callable that performs the task.
        deps: set of the objects passed as ``deps`` (empty when omitted).
        name: registry key; defaults to ``fun.__name__``.
        do_cleanup: zero-argument callable; a no-op until ``cleanup`` is
            used to register a real one.
    """

    _tasks = {}  # job registry: see run_job below

    def __init__(self, fun, name=None, deps=None, **kwargs):
        """Create the task and register it under its name.

        Args:
            fun: callable executed as the task's work.
            name: optional registry name; falls back to ``fun.__name__``.
            deps: optional iterable of tasks that must finish first.
            **kwargs: accepted for forward compatibility and ignored.
        """
        self.do_work = fun
        self.deps = set(deps or [])
        self.name = name or fun.__name__
        # A later task with the same name replaces the earlier registry entry.
        Task._tasks[self.name] = self

    @staticmethod
    def do_cleanup():
        """Default cleanup: do nothing.

        Having a default means a task that never registers a cleanup hook
        can still be treated uniformly by code that calls ``do_cleanup``.
        """

    def cleanup(self, fun):
        """Register ``fun`` as this task's cleanup hook (usable as a decorator).

        Returns ``fun`` unchanged so the decorated name stays bound to the
        function instead of being rebound to ``None``.
        """
        self.do_cleanup = fun
        return fun

    @classmethod
    def get(cls, name):
        """Return the registered task called ``name``.

        Raises:
            KeyError: if no task was registered under that name.
        """
        return cls._tasks[name]
def task(*args, **kwargs):
    """Build a decorator that turns a function into a ``Task``.

    The positional and keyword arguments given here are forwarded to the
    ``Task`` constructor after the decorated function itself, so the
    decorated name ends up bound to the ``Task`` instance rather than to
    the original function.
    """

    def wrap(function) -> Task:
        new_task = Task(function, *args, **kwargs)
        return new_task

    return wrap
@task()
def job1():
    """Demo task without dependencies: announce, sleep one second, finish."""
    print("doing job1")
    sleep(1)
    print("job1 done")


@job1.cleanup
def _():
    """Cleanup hook registered for job1."""
    print("cleaning up job1's mess")
@task(deps=[job1])
def job2():
    """Demo task depending on job1: announce, sleep one second, finish."""
    print("doing job2")
    sleep(1)
    print("job2 done")


@job2.cleanup
def _():
    """Cleanup hook registered for job2."""
    print("cleaning up job2's mess")
@task(deps=[job1])
def job3():
    """Demo task depending on job1 that always fails.

    Announces itself, sleeps two seconds and then raises, which exercises
    the failure path of the executor.
    """
    print("doing job3")
    sleep(2)
    raise Exception("job3 failed")


@job3.cleanup
def _():
    """Cleanup hook registered for job3."""
    print("cleaning up job3's mess")
@task(deps=[job2, job3])
def job4():
    """Demo task depending on both job2 and job3: announce and finish."""
    print("doing job4")
    print("job4 done")


@job4.cleanup
def _():
    """Cleanup hook registered for job4."""
    print("cleaning up job4's mess")
def run_job(job_name: str):
    """Look up the task registered as ``job_name`` and run its work.

    This is the function submitted to the worker pool. You can't pass
    functions across process boundaries, so only the job name is sent and
    the task is resolved from the registry on the receiving side.

    Returns whatever the task's ``do_work`` callable returns.
    """
    job = Task.get(job_name)
    return job.do_work()
def run_dag(jobs: Iterable["Task"], max_workers: int = 12):
    """Run ``jobs`` in a process pool, respecting their dependencies.

    A job is submitted as soon as all of its ``deps`` have finished. When a
    job is submitted its cleanup hook (if it has one) is recorded; after the
    pool has shut down, the recorded hooks are called in reverse submission
    order, whether the run succeeded or failed.

    Args:
        jobs: the tasks to run. Every dependency of every task must itself
            be part of ``jobs``.
        max_workers: size of the process pool (defaults to 12).

    Raises:
        ValueError: if a dependency is missing from ``jobs`` or the
            dependencies form a cycle, so that no progress is possible.
        Exception: the exception raised by the first failed job seen is
            re-raised after pending jobs are cancelled.
    """
    # Local import keeps this block self-contained.
    from contextlib import ExitStack

    waiting_jobs = set(jobs)

    # A dependency outside the job set can never finish; fail fast instead
    # of looping forever.
    missing = {dep for job in waiting_jobs for dep in job.deps} - waiting_jobs
    if missing:
        names = ", ".join(sorted(str(dep.name) for dep in missing))
        raise ValueError("dependencies not included in jobs: {}".format(names))

    running_jobs = {}
    finished_jobs = set()

    # Context managers exit right-to-left: the pool is shut down (waiting
    # for in-flight jobs) *before* the cleanup callbacks run, so cleanups
    # never race with still-running workers. ExitStack calls the callbacks
    # in LIFO order and keeps going even if one of them raises.
    with ExitStack() as cleanups, ProcessPoolExecutor(max_workers=max_workers) as pool:
        try:
            while waiting_jobs or running_jobs:
                ready = {
                    job for job in waiting_jobs
                    if job.deps.issubset(finished_jobs)
                }
                if not ready and not running_jobs:
                    # Nothing can start and nothing is in flight: a cycle.
                    names = ", ".join(sorted(str(job.name) for job in waiting_jobs))
                    raise ValueError(
                        "dependency cycle among jobs: {}".format(names)
                    )

                for job in ready:
                    future = pool.submit(run_job, job.name)
                    running_jobs[future] = job
                    # Tolerate tasks that never registered a cleanup hook.
                    cleanup = getattr(job, "do_cleanup", None)
                    if cleanup is not None:
                        cleanups.callback(cleanup)
                waiting_jobs -= ready

                done, _ = wait(running_jobs, return_when=FIRST_COMPLETED)
                for future in done:
                    future.result()  # re-raises the job's exception, if any
                    finished_jobs.add(running_jobs.pop(future))
        except BaseException:
            # Don't start queued jobs once the run has failed; jobs that
            # are already executing are awaited by the pool's shutdown.
            for future in running_jobs:
                future.cancel()
            raise
def collect_deps(tasks: Iterable["Task"]) -> Set["Task"]:
    """Return ``tasks`` together with all of their transitive dependencies.

    Walks the ``deps`` attribute of each task breadth-first. Tasks that
    have already been collected are never expanded again, so shared
    dependencies are visited once and cyclic dependency graphs terminate
    instead of looping forever.

    Args:
        tasks: the starting tasks.

    Returns:
        A new set containing the starting tasks and everything reachable
        from them through ``deps``.
    """
    collected = set()
    frontier = set(tasks)
    while frontier:
        collected |= frontier
        # Only follow dependencies that have not been seen yet.
        frontier = {dep for job in frontier for dep in job.deps} - collected
    return collected
if __name__ == "__main__":
    # Guard the entry point: with the "spawn" start method the worker
    # processes re-import this module, and an unguarded call would try to
    # run the whole DAG again inside every worker.
    run_dag(collect_deps([job4]))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment