Skip to content

Instantly share code, notes, and snippets.

@acarapetis
Last active August 15, 2022 08:04
Show Gist options
  • Save acarapetis/a103661a2bf0789b2c7863c3de8aa870 to your computer and use it in GitHub Desktop.
simple task DAG executor in python
from concurrent.futures import ProcessPoolExecutor, wait, FIRST_COMPLETED
from typing import Iterable, Set
from time import sleep
class Task:
    """A named unit of work with optional dependencies and a cleanup hook.

    Every Task registers itself in the class-level ``_tasks`` dict under its
    name so worker processes can look jobs up by name — functions themselves
    cannot be pickled across process boundaries (see ``run_job`` below).
    """

    _tasks = {}  # job registry: name -> Task; see run_job below

    def __init__(self, fun, name=None, deps=None, **kwargs):
        self.do_work = fun
        self.deps = set(deps or [])
        self.name = name or fun.__name__
        # Default to a no-op so run_dag can always invoke do_cleanup even
        # when no cleanup handler was registered (previously AttributeError).
        self.do_cleanup = lambda: None
        Task._tasks[self.name] = self

    def cleanup(self, fun):
        """Decorator registering *fun* as this task's cleanup handler."""
        self.do_cleanup = fun
        # Return the function so decorator use doesn't bind the name to None.
        return fun

    @classmethod
    def get(cls, name):
        """Look up a previously registered task by name."""
        return cls._tasks[name]
def task(*args, **kwargs):
    """Decorator factory: wrap a plain function in a registered Task.

    All arguments are forwarded to the Task constructor, e.g.
    ``@task(deps=[other_job])`` or ``@task(name="custom")``.
    """
    def wrap(fn) -> Task:
        return Task(fn, *args, **kwargs)

    return wrap
def job1():
    """Demo job with no dependencies."""
    print("doing job1")
    sleep(1)
    print("job1 done")


# Explicit decorator application: job1 is now a Task instance.
job1 = task()(job1)


def _cleanup_job1():
    print("cleaning up job1's mess")


job1.cleanup(_cleanup_job1)
def job2():
    """Demo job that runs only after job1 has finished."""
    print("doing job2")
    sleep(1)
    print("job2 done")


# Explicit decorator application: job2 is now a Task depending on job1.
job2 = task(deps=[job1])(job2)


def _cleanup_job2():
    print("cleaning up job2's mess")


job2.cleanup(_cleanup_job2)
def job3():
    """Demo job that deliberately fails, to exercise the error path."""
    print("doing job3")
    sleep(2)
    raise Exception("job3 failed")


# Explicit decorator application: job3 is now a Task depending on job1.
job3 = task(deps=[job1])(job3)


def _cleanup_job3():
    print("cleaning up job3's mess")


job3.cleanup(_cleanup_job3)
def job4():
    """Demo job gated on both job2 and job3 (never runs: job3 fails)."""
    print("doing job4")
    print("job4 done")


# Explicit decorator application: job4 depends on both job2 and job3.
job4 = task(deps=[job2, job3])(job4)


def _cleanup_job4():
    print("cleaning up job4's mess")


job4.cleanup(_cleanup_job4)
def run_job(job_name: str):
    """Resolve *job_name* against the Task registry and run its work function.

    Only the (picklable) name crosses the process boundary — the worker
    process looks the Task object up on its own side.
    """
    job = Task.get(job_name)
    return job.do_work()
def run_dag(jobs: Iterable[Task]):
    """Execute *jobs* on a process pool, respecting their dependency edges.

    Jobs whose dependencies have all finished are submitted; whenever at
    least one running job completes, the waiting set is re-scanned for newly
    ready jobs.  The first job failure is re-raised, and in every case the
    cleanup handlers of all jobs that were actually submitted run in reverse
    submission order.

    Raises:
        RuntimeError: if no progress can be made — a dependency cycle, or a
            dependency that is not part of *jobs*.
        Exception: whatever a failed job raised.
    """
    waiting_jobs = set(jobs)
    running_jobs = dict()  # Future -> Task
    finished_jobs = set()
    cleanup_stack = []
    try:
        # Context manager guarantees the worker processes are shut down even
        # when a job raises (the pool was previously leaked on failure).
        with ProcessPoolExecutor(max_workers=12) as pool:
            while waiting_jobs or running_jobs:
                ready = {job for job in waiting_jobs if job.deps.issubset(finished_jobs)}
                if not ready and not running_jobs:
                    # Nothing is running and nothing can start: the remaining
                    # jobs have unsatisfiable deps.  (The original busy-spun
                    # forever here, calling wait() on an empty set.)
                    names = sorted(job.name for job in waiting_jobs)
                    raise RuntimeError(
                        "DAG deadlock: unsatisfiable dependencies for %s" % names
                    )
                for job in ready:
                    future = pool.submit(run_job, job.name)
                    running_jobs[future] = job
                    cleanup_stack.append(job.do_cleanup)
                    waiting_jobs.discard(job)
                done, _still_running = wait(running_jobs, return_when=FIRST_COMPLETED)
                for future in done:
                    exc = future.exception()
                    if exc is not None:
                        raise exc
                    finished_jobs.add(running_jobs.pop(future))
    finally:
        # LIFO cleanup of every job that was submitted, successful or not.
        for f in reversed(cleanup_stack):
            f()
def collect_deps(tasks: Iterable[Task]) -> Set[Task]:
    """Return the given tasks plus all of their transitive dependencies.

    Performs a breadth-first walk over ``deps`` edges.  Already-collected
    tasks are pruned from each new frontier, which guarantees termination
    even if the dependency graph accidentally contains a cycle (the
    original looped forever in that case).
    """
    collected = set()
    frontier = set(tasks)
    while frontier:
        collected.update(frontier)
        # Keep only unseen deps — the pruning is the termination guarantee.
        frontier = {dep for job in frontier for dep in job.deps} - collected
    return collected
if __name__ == "__main__":
    # The guard is required, not cosmetic: under the multiprocessing "spawn"
    # start method (default on Windows/macOS), ProcessPoolExecutor workers
    # re-import this module, and an unguarded top-level call would make each
    # worker recursively try to run the whole DAG.
    run_dag(collect_deps([job4]))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment