@linar-jether
Last active March 16, 2022 19:16
Dynamic Celery tasks: remote execution of arbitrary callables and DAGs, using dill to serialize executable code and send it to a worker. This also shows how to map an iterable returned from one task onto a group of tasks (a distributed map), with an optional reducer (chord) executed when the group's tasks complete.
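The snippet below assumes a configured Celery app named app with the canvas primitives in scope. A minimal sketch of that setup might look like this (module name, broker URL, and backend are placeholders; a result backend is needed for the .get() calls at the end to work):

from celery import Celery, chain, chord, group
from celery import signature as subtask  # `subtask` is the legacy name used below

# Placeholder broker/backend URLs; point these at your own infrastructure
app = Celery('dynamic_tasks',
             broker='pyamqp://guest@localhost//',
             backend='rpc://')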
# Task primitives, allows pipeline execution using celery

@app.task
def dmap(it, callback, final=None):
    # Map a callback over an iterable and return the results as a group
    callback = subtask(callback)

    # Hack for mapping a chain over values, due to a bug where args are not copied on group creation
    if isinstance(callback, chain):
        if final:
            raise ValueError('task_processor: Cannot run reducer for dmap executed with a chain.')
        return [callback.delay(arg) for arg in it]

    run_in_parallel = group(callback.clone([arg, ]) for arg in it)
    if len(run_in_parallel.tasks) == 0:
        return []
    if final:
        # Attach the reducer as a chord callback over the group
        return chord(run_in_parallel)(final)
    return run_in_parallel.delay()
@app.task
def _exec_async(*args, **kwargs):
    import dill
    import zlib
    serialized_callable = kwargs.pop('_serialized_callable', None)
    if serialized_callable is None:
        raise ValueError('Missing _serialized_callable kwarg, must contain the dill-serialized callable.')
    # Payload may be zlib-compressed; fall back to the raw bytes if not
    try:
        serialized_callable = zlib.decompress(serialized_callable)
    except zlib.error:
        pass
    return dill.loads(serialized_callable)(*args, **kwargs)
def exec_async(func, queue=None, *args, **kwargs):
    import dill
    if queue is None:
        raise ValueError("Missing queue name")
    serialized_callable = dill.dumps(func)
    kwargs['_serialized_callable'] = serialized_callable
    args_ = list(args)
    if kwargs.pop('_as_sig', False):
        # Return a signature (for composing into chains/chords) instead of applying immediately
        return _exec_async.s(*args_, **kwargs).set(queue=queue)
    return _exec_async.apply_async(args_, kwargs, queue=queue, compression='zlib')


def async_sig(func, queue=None, *args, **kwargs):
    # Renamed from `async`, which is a reserved keyword as of Python 3.7
    kwargs['_as_sig'] = True
    return exec_async(func, queue, *args, **kwargs)
# DAG components

def get_list_of_items():
    return list(range(100))


def do_something_for_item(item):
    import numpy as np
    return item * np.random.rand()


def get_top_10(items):
    import pandas as pd
    return pd.Series(items).sort_values()[-10:]


queue_name = 'test_queue'
dag = async_sig(get_list_of_items, queue_name) | async_sig(
    dmap, queue_name,
    callback=async_sig(do_something_for_item, queue_name),
    final=async_sig(get_top_10, queue_name))

# Execute the DAG; the chain's result resolves to the chord's AsyncResult, hence the double .get()
dag.delay().get().get()

# Execute a simple lambda
async_res = exec_async(lambda x: x * x, queue_name, 5)
async_res.get()
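For the .get() calls above to return, a worker must be consuming from test_queue. Assuming the code lives in a module named tasks (a placeholder), a worker could be started with:

celery -A tasks worker -Q test_queue --loglevel=info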
@linar-jether (Author):

Might be better to use cloudpickle instead of dill here since it handles external imports better
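Swapping it in would be confined to the serialization call; here is a minimal sketch (the name exec_async_cp is illustrative, not from the original gist). The worker-side _exec_async can stay unchanged, since cloudpickle emits standard pickle streams that dill.loads can read:

import cloudpickle

def exec_async_cp(func, queue=None, *args, **kwargs):
    # Same as exec_async above, but serializes the callable with cloudpickle,
    # which captures dynamically defined functions and their closures by value
    if queue is None:
        raise ValueError("Missing queue name")
    kwargs['_serialized_callable'] = cloudpickle.dumps(func)
    args_ = list(args)
    if kwargs.pop('_as_sig', False):
        return _exec_async.s(*args_, **kwargs).set(queue=queue)
    return _exec_async.apply_async(args_, kwargs, queue=queue, compression='zlib')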
