Testing complex pipelines in Celery
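These tests run tasks on a real worker via the celery_worker fixture, so they assume Celery's bundled pytest plugin (celery.contrib.pytest) is enabled in a conftest.py. The gist does not include one; the sketch below is a minimal assumption using an in-memory broker and result backend:

# conftest.py -- hypothetical; not part of the original gist.
import pytest

# Enables the `celery_app` and `celery_worker` fixtures used by the tests below.
pytest_plugins = ("celery.contrib.pytest",)


@pytest.fixture(scope="session")
def celery_config():
    # In-memory broker/backend keep the test run free of external services.
    return {
        "broker_url": "memory://",
        "result_backend": "cache+memory://",
    }

The test module itself follows.
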
from dataclasses import dataclass
from enum import Enum
from time import sleep
from typing import Any, List

from celery import Celery, chain, group, shared_task, states
from celery.result import AsyncResult, GroupResult


class JobStatus(str, Enum):
    PENDING = "pending"
    STARTED = "started"
    FINISHED = "finished"
    FAILURE = "failure"
    CANCELLED = "cancelled"


# Declared as a dataclass (an assumption) so that Job(...) accepts keyword
# arguments and instances compare by value, as the assertions below expect.
@dataclass
class Job:
    id: str
    status: JobStatus = JobStatus.PENDING
    result: Any = None

    @classmethod
    def generate_id(cls) -> str:
        return "I'm a job ID"

    @classmethod
    def from_group_result(cls, group_result: GroupResult) -> "Job":
        child_results = [child.result for child in group_result.children]
        child_statuses = [child.status for child in group_result.children]
        job_status = JobStatus.PENDING
        if states.STARTED in child_statuses:
            job_status = JobStatus.STARTED
        if all(status == states.FAILURE for status in child_statuses):
            job_status = JobStatus.FAILURE
        elif all(status in states.READY_STATES for status in child_statuses):
            # FAILURE is one of Celery's READY_STATES, so a group where only
            # some children failed still collapses to FINISHED
            # (see test_group_partial_error below).
            job_status = JobStatus.FINISHED
        return Job(id=group_result.id, status=job_status, result=child_results)


@shared_task
def do_A(steps: List[int]) -> List[int]:
    steps.append(1)
    return steps


@shared_task
def do_B(steps: List[int]) -> List[int]:
    steps.append(2)
    sleep(0.3)  # so we know it will be pending unless we wait for it explicitly
    return steps


def test_chain(celery_app: Celery, celery_worker):  # pylint: disable=unused-argument
    result: AsyncResult = chain(do_A.s([]), do_B.s()).apply_async()
    assert not result.result  # applied asynchronously, not finished yet!
    assert result.get() == [1, 2]  # Here we wait for it to finish
    assert result.result  # And then the result stays cached on the AsyncResult


def test_job_from_multi_chains_finished(
    celery_app: Celery, celery_worker
):  # pylint: disable=unused-argument
    chain_group = group(
        chain(do_A.s([]), do_B.s()), chain(do_A.s([]), do_B.s())
    )
    group_result: GroupResult = chain_group.apply_async(task_id=Job.generate_id())
    assert group_result.get() == [[1, 2], [1, 2]]
    assert Job.from_group_result(group_result) == Job(
        id=group_result.id, status=JobStatus.FINISHED, result=[[1, 2], [1, 2]]
    )


def test_job_from_chain_to_group(
    celery_app: Celery, celery_worker
):  # pylint: disable=unused-argument
    score_gene_group = group(do_B.s(), do_B.s())
    score_gene_group.options["task_id"] = Job.generate_id()
    chain_to_group = chain(do_A.s([]), score_gene_group)
    group_result: GroupResult = chain_to_group.apply_async()
    assert group_result.get() == [[1, 2], [1, 2]]
    assert Job.from_group_result(group_result) == Job(
        id=group_result.id, status=JobStatus.FINISHED, result=[[1, 2], [1, 2]]
    )


def test_group_error(
    celery_app: Celery, celery_worker
):  # pylint: disable=unused-argument
    error_group = group(
        do_A.s(1),
        do_A.s(1),  # Bad values lead to errors, since int has no 'append'
    )
    group_result: GroupResult = error_group.apply_async(task_id=Job.generate_id())
    group_result.get(propagate=False)
    job = Job.from_group_result(group_result)
    assert job.status == JobStatus.FAILURE
    assert job.id == group_result.id
    assert len(job.result) == 2
    assert isinstance(job.result[0], AttributeError)
    assert isinstance(job.result[1], AttributeError)


def test_group_partial_error(
    celery_app: Celery, celery_worker
):  # pylint: disable=unused-argument
    error_group = group(do_A.s(1), do_A.s([]))
    group_result: GroupResult = error_group.apply_async(task_id=Job.generate_id())
    assert len(group_result.children) == 2
    group_result.get(propagate=False)
    job = Job.from_group_result(group_result)
    assert job.status == JobStatus.FINISHED
    assert job.id == group_result.id
    assert len(job.result) == 2
    assert isinstance(job.result[0], AttributeError)
    assert job.result[1] == [1]


def test_job_from_multi_chains_partial_pending(
    celery_app: Celery, celery_worker
):  # pylint: disable=unused-argument
    chain_group = group(
        chain(do_A.s([]), do_B.s()), chain(do_A.s([]), do_B.s())
    )
    group_result: GroupResult = chain_group.apply_async(task_id=Job.generate_id())
    assert group_result.children[0].get() == [1, 2]  # Wait for one chain to finish
    job = Job.from_group_result(group_result)
    assert job == Job(
        id=group_result.id, status=JobStatus.PENDING, result=[[1, 2], None]
    )