Example Parallel Task API based on Celery
# (c) Copyright 2018 Zymergen, Inc.
# All Rights Reserved
"""
The following is example code used for a technology blog post: https://medium.com/@ZymergenTechBlog/building-a-parallel-task-api-with-celery-dbae5ced4e28
The ForkJoin class can be used to generate a ZWork task that contains a single
distributed processing step. Your job should have three parts: an initial setup step
responsible for splitting inputs into workable chunks, a process step that
processes each chunk in a forked execution process, and a join step that puts it all
together and returns the final result.
"""
import copy
import functools
from types import GeneratorType
import cerberus
from celery import chain, group, shared_task
# Parallelization settings
PARALLEL_DEFAULT_GROUP_SIZE = 32
class InputValidationError(TypeError):
"""Exception representing a failure to validate job inputs against a Cerberus schema.
Attributes:
inputs: The inputs that failed validation
schema: The Cerberus schema
errors: The list of validation errors
"""
def __init__(self, inputs, schema, errors):
self.inputs = inputs
self.schema = schema
self.errors = errors
self.warnings = None
err_template = "Inputs did not match schema.\n\tInputs: {}\n\tSchema: {}\n\tErrors: {}"
err_msg = err_template.format(self.inputs, self.schema, self.errors)
super(InputValidationError, self).__init__(err_msg)
def chunks(l, n):
"""A utility function for splitting a list into n roughly-equal sized chunks.
"""
return [l[i::n] for i in xrange(n)]
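# Illustrative example (not from the original gist): chunks([1, 2, 3, 4, 5], 2)
# returns [[1, 3, 5], [2, 4]] -- the items are dealt out round-robin into n
# chunks whose sizes differ by at most one.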
def flatten(xs):
"""Flatten a list.
Turns a nested list into a flat list.
Only goes one layer deep, not fully nested.
Args:
xs: a list of elements and lists
    Returns:
        A generator over the same elements, with one level of list/generator nesting removed.
"""
for x in xs:
if isinstance(x, list) or isinstance(x, GeneratorType):
for elem in x:
yield elem
else:
yield x
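# Illustrative example (not from the original gist):
# list(flatten([[1, 2], [3], 4])) == [1, 2, 3, 4]; only one level of nesting is removed.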
def validate_inputs(ins, schema):
"""Validates JSON inputs against a Cerberus schema.
"""
validator = cerberus.Validator()
is_valid = validator.validate(ins, schema)
if not is_valid:
raise InputValidationError(ins, schema, validator.errors)
# def strips_extra_args(fn):
# """Makes a function able to accept arbitrary extra keyword arguments.
#
# Returns a version of fn that can safely receive extra arguments, and raises
# an exception upon incorrect argumentation that is more useful than Python's
# default.
#
# Details omitted
#
# Args:
# fn: any function
# Returns:
# fn, but able to accept extra keyword arguments
# Raises:
# TypeError is raised whenever too few arguments are passed in.
# """
# return fn
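# ---------------------------------------------------------------------------
# NOTE: the real strips_extra_args implementation is omitted above ("Details
# omitted"). The sketch below is a hypothetical stand-in, NOT Zymergen's code;
# it is included only so this example module imports and the decorators below
# can run. It drops keyword arguments that fn does not accept, and does not
# implement the friendlier error messages described in the docstring above.
# ---------------------------------------------------------------------------
import inspect


def strips_extra_args(fn):
    """Sketch: wrap fn so that unexpected keyword arguments are silently dropped."""
    sig = inspect.signature(fn)
    accepts_var_kwargs = any(p.kind is inspect.Parameter.VAR_KEYWORD
                             for p in sig.parameters.values())

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if accepts_var_kwargs:
            # fn already takes **kwargs; pass everything through untouched.
            return fn(*args, **kwargs)
        allowed = {k: v for k, v in kwargs.items() if k in sig.parameters}
        return fn(*args, **allowed)
    return wrapper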
class ForkJoinTask(object):
"""Represents a task that can be run in parallel.
To implement your own parallel task:
1. Write a new class, MyTask, that inherits from this one
2. Override the command and job_inputs properties
3. Write business logic functions outside of the MyTask class
- a "setup" for dividing inputs into units of work, each unit is a dictionary
with keys that can be used as arg names for the process step
- a "processer" for performing a single unit of work defined by the setup step
- a "joiner" for recombining results from the work units
4. Decorate those functions with MyTask.setup(), MyTask.process(), and MyTask.join()
- see their docstrings below for details, restrictions, and expectations
5. Input args are saved as 'orig_input_args`, which is dict of the original args
that can be accessed in individual method signatures (e.g., process, join).
Why do I need to implement a new static singleton? Can't this be more dynamic?
No. This limitation comes from Celery, not us.
"""
queue = 'parallel'
def run_without_celery(self, job_args):
"""
        Run this task as a job without using the Celery backend. The entire task runs in a
        single process, with no serialization/deserialization to/from JSON between steps,
        but it should otherwise mimic the Celery case exactly.
Args:
job_args: The input dictionary to the job
Returns:
The job results dictionary.
"""
# avoid changing the global state of original job args
nested_job_args = job_args.copy()
nested_job_args['orig_input_args'] = job_args.copy()
setup_results = self.setup_step((), **nested_job_args)
process_results = [self.process_step(setup_results, i, **nested_job_args)
for i in range(len(setup_results))]
return self.join_step(process_results, **nested_job_args)
@classmethod
def get_input_schema(cls, *args, **kwargs):
"""This method provides a Cerberus schema for job input validation.
Returns:
A Cerberus schema that will be used to validate job inputs.
"""
raise NotImplementedError("Tasks MUST implement a schema")
@classmethod
def split(cls, work_units):
"""Override this to explicitly control which workers do which work.
You can almost certainly use this default.
Args:
work_units (list<dict>): The inputs for each possible process() step
This is the output from the setup() step.
Returns:
A list of list<dict>s. Each parent list corresponds to work for one worker.
"""
expected_workers = PARALLEL_DEFAULT_GROUP_SIZE
return chunks(work_units, min(expected_workers, len(work_units)))
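    # Illustrative example (not from the original gist): with 100 work units and
    # the default group size of 32, split() deals them round-robin into 32 lists
    # of 3-4 units each; with only 5 work units it returns 5 single-item lists.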
@classmethod
def setup(cls, task):
"""Decorates a function that divides job inputs into units of work.
This decorator MUST be used EXACTLY ONCE per job like this:
@MyTask.setup
def divide_job_inputs(arg0_from_schema, arg1_from_schema, ...):
return [{'arg': 'val0'}, {'arg': 'val1'}]
Args (passed into your decorated function):
**kwargs (dict): Corresponds to parsed and validated job inputs.
                These should match keys from your job's get_input_schema().
Return (expected from your decorated function):
A list of dicts. Each dict represents a subset of inputs needed to do one
unit of parallel work. The keys in this dict should match the params
            from your process() step function. Remaining inputs to process() are
provided by parsed/validated job inputs.
"""
@functools.wraps(task)
def setter_upper(*args, **kwargs):
# We use cerberus for schema validation
schema = strips_extra_args(cls.get_input_schema)(**kwargs)
validate_inputs(kwargs, schema)
normalized_inputs = cerberus.Validator().normalized(kwargs, schema)
result = strips_extra_args(task)(**normalized_inputs)
split_results = cls.split(result)
return split_results
cls.setup_step = lifecycle_task(setter_upper, queue=cls.queue)
return cls.setup_step
@classmethod
def process(cls, task):
"""Decorates a function that performs one unit of parallel work.
This decorator MUST be used EXACTLY ONCE per job like this:
@MyTask.process
def do_parallel_work(param0, param1, ...):
return "whatever I want"
Args (passed into your decorated function):
**kwargs (dict): All the parsed/validated job inputs, PLUS all the
arguments from one of setup()'s outputs. Your function will be
passed only the arguments explicitly claimed in its signature.
In the event of key collision, arguments from one of setup()'s
                outputs take precedence over parsed/validated job inputs.
Return (expected from your decorated function):
Your function can return any serializable object.
"""
@functools.wraps(task)
def process_inputs(divided_inputs, group_index, **kwargs):
try:
my_share_of_work = divided_inputs[group_index]
except IndexError:
return []
outs = []
schema = strips_extra_args(cls.get_input_schema)(**kwargs)
normalized_kwargs = cerberus.Validator().normalized(kwargs, schema)
for work_unit in my_share_of_work:
work_unit_kwargs = copy.deepcopy(normalized_kwargs)
work_unit_kwargs.update(work_unit)
result = strips_extra_args(task)(**work_unit_kwargs)
outs.append(result)
return outs
cls.process_step = lifecycle_task(process_inputs, queue=cls.queue)
return cls.process_step
@classmethod
def join(cls, task):
"""Decorates a function that recombines parallel results.
This decorator MUST be used EXACTLY ONCE per job like this:
@MyTask.join
def recombine_parallel_results(parallel_results, **kwargs):
return "whatever I want"
Args (passed into your decorated function):
parallel_results (list): A list of all values that were returned from
an invocation of the process() step.
**kwargs (dict): All the parsed/validated job inputs.
Return (expected from your decorated function):
Any serializable object. This object is the job's final result.
"""
@functools.wraps(task)
def joiner(distributed_results, *args, **kwargs):
schema = strips_extra_args(cls.get_input_schema)(**kwargs)
normalized_kwargs = cerberus.Validator().normalized(kwargs, schema)
# The process step nested its results. Unnest them.
unnested_results = list(flatten(distributed_results))
results = strips_extra_args(task)(unnested_results, *args, **normalized_kwargs)
return results
# place the final step on a different queue so it's not blocked by other running
# parallel jobs
cls.join_step = lifecycle_task(joiner, queue='celery')
return cls.join_step
@classmethod
def signature(cls, options):
"""Returns a Celery signature to describe the complete job-step sequence.
Args:
options (dict): The shared options for the job (auth, request, etc.)
Returns:
A Celery signature.
"""
return fork_join_task(cls.setup_step, cls.process_step, cls.join_step, options)
def fork_join_task(setup_step, process_step, join_step, bound_args):
"""Creates a parallel Celery fork/join task from provided functions.
Args:
setup_step (celery task): A "setup" step for the whole job
process_step (celery task): A "process" step that runs in parallel after setup
join_step (celery task): A "join" step to recombine parallel results
bound_args (dict): Any bound arguments that can be accessed by all steps
Returns:
A new Celery job that performs a setup/process/join work pattern.
The returned job's steps will all be partially applied over bound_args.
"""
setup_sig = setup_step.signature(**bound_args)
process_sig = parallel_processing_step(process_step, bound_args)
join_sig = join_step.signature(**bound_args)
return chain(setup_sig, process_sig, join_sig)
def parallel_processing_step(
process_step, bound_args, group_size=PARALLEL_DEFAULT_GROUP_SIZE):
"""Returns a "group" signature for a distributed application of process_fn.
"""
signatures = [process_step.signature(group_index=i, **bound_args) for i in range(group_size)]
return group(signatures)
def lifecycle_task(task, queue):
"""Makes a Celery task from a Python function.
This runner can act as a single step within a parallel job.
Args:
task (fn): the function to run
queue (str): The name of the Celery queue to use.
Returns:
A Celery task that runs the provided function
"""
name = "{}.{}".format(task.__module__, task.__name__)
@shared_task(ignore_result=False,
name=name, queue=queue, options={'queue': queue})
def internal_runner(*args, **kwargs):
return task(*args, **kwargs)
return internal_runner
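# ---------------------------------------------------------------------------
# Usage sketch (not part of the original gist). Everything below is a minimal,
# hypothetical example -- the names SquareSumTask, divide_numbers, square_one
# and sum_squares are illustrative assumptions -- showing how the setup/process/
# join decorators and run_without_celery() fit together, as described in the
# ForkJoinTask docstring. It assumes the strips_extra_args sketch above (or a
# real implementation) is available.
# ---------------------------------------------------------------------------
class SquareSumTask(ForkJoinTask):
    """Squares a list of numbers in parallel and sums the results."""

    @classmethod
    def get_input_schema(cls, *args, **kwargs):
        # Cerberus schema describing the job inputs. run_without_celery()
        # injects the original args under 'orig_input_args', so the schema
        # must tolerate that key (cerberus rejects unknown keys by default).
        return {
            'numbers': {'type': 'list', 'required': True},
            'orig_input_args': {'type': 'dict', 'required': False},
        }


@SquareSumTask.setup
def divide_numbers(numbers):
    # One work unit per number; the key matches the process step's parameter.
    return [{'value': n} for n in numbers]


@SquareSumTask.process
def square_one(value):
    # One unit of parallel work.
    return value * value


@SquareSumTask.join
def sum_squares(parallel_results, **kwargs):
    # Recombine the per-unit results into the job's final result.
    return sum(parallel_results)


if __name__ == '__main__':
    # Runs the whole setup/process/join sequence in a single process,
    # without a Celery broker. Expected output: 30 (1 + 4 + 9 + 16).
    print(SquareSumTask().run_without_celery({'numbers': [1, 2, 3, 4]}))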