Last active
July 11, 2022 08:07
-
-
Save ymorgenstern/bb09baee174ea37e1ea5a6b1e7655981 to your computer and use it in GitHub Desktop.
Testing complex pipelines in Celery
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from dataclasses import dataclass
from enum import Enum
from time import sleep
from typing import Any, List

from celery import Celery, chain, group, shared_task, states
from celery.result import AsyncResult, GroupResult
class JobStatus(str, Enum):
    """Lifecycle states of a Job, aggregated from its Celery child tasks.

    Subclasses ``str`` so members compare/serialize as their plain string
    values.
    """

    PENDING = "pending"
    STARTED = "started"
    FINISHED = "finished"
    FAILURE = "failure"
    CANCELLED = "cancelled"
@dataclass
class Job:
    """Single aggregated view over a Celery group of tasks.

    The original version declared only class-level annotations, so
    ``Job(id=..., status=..., result=...)`` raised ``TypeError`` and ``==``
    compared identity; ``@dataclass`` provides the keyword constructor and
    value-based equality that the tests rely on.

    Attributes:
        id: Identifier shared with the underlying Celery group result.
        status: Aggregate status derived from the child task states.
        result: Per-child results (return values or exceptions), in group
            order; ``None`` until populated.
    """

    id: str
    status: JobStatus = JobStatus.PENDING
    result: Any = None

    @classmethod
    def generate_id(cls) -> str:
        """Return a new job identifier (stub value for these tests)."""
        return "I'm a job ID"

    @classmethod
    def from_group_result(cls, group_result: GroupResult) -> "Job":
        """Build a Job summarizing ``group_result``.

        Aggregation rules, applied in order:
        - STARTED if any child has started;
        - FAILURE only when *every* child failed (overrides STARTED);
        - FINISHED when every child reached a ready state;
        - otherwise the status stays PENDING.
        """
        child_results = [child.result for child in group_result.children]
        child_statuses = [child.status for child in group_result.children]
        job_status = JobStatus.PENDING
        if states.STARTED in child_statuses:
            job_status = JobStatus.STARTED
        if all(status == states.FAILURE for status in child_statuses):
            job_status = JobStatus.FAILURE
        elif all(status in states.READY_STATES for status in child_statuses):
            job_status = JobStatus.FINISHED
        # cls(...) rather than Job(...) keeps subclassing friendly.
        return cls(id=group_result.id, status=job_status, result=child_results)
@shared_task
def do_A(steps: List[int]) -> List[int]:
    """Append step marker ``1`` to the accumulated step list and return it.

    Note: mutates ``steps`` in place; the error tests deliberately pass an
    ``int`` so that ``.append`` raises ``AttributeError`` in the worker.
    """
    steps.append(1)
    return steps
@shared_task
def do_B(steps: List[int]) -> List[int]:
    """Record step marker ``2`` on the step list, after a short stall.

    The delay keeps the task observably pending unless the caller waits
    for it explicitly.
    """
    sleep(0.3)  # so we know it will be pending unless we wait for it explicitly
    steps.append(2)
    return steps
def test_chain(celery_app: Celery, celery_worker):  # pylint: disable=unused-argument
    """A two-step chain has no result until explicitly waited on."""
    async_result: AsyncResult = chain(do_A.s([]), do_B.s()).apply_async()
    # Fired asynchronously — nothing is available immediately.
    assert not async_result.result
    # Block until the whole chain completes.
    assert async_result.get() == [1, 2]
    # Once resolved, the value stays cached on the result object.
    assert async_result.result
def test_job_from_multi_chains_finished(
    celery_app: Celery, celery_worker
):  # pylint: disable=unused-argument
    """A group of two completed chains aggregates to a FINISHED Job."""
    two_chains = group(
        chain(do_A.s([]), do_B.s()),
        chain(do_A.s([]), do_B.s()),
    )
    group_result: GroupResult = two_chains.apply_async(task_id=Job.generate_id())
    # Wait for every chain to finish before building the job view.
    assert group_result.get() == [[1, 2], [1, 2]]
    expected = Job(
        id=group_result.id, status=JobStatus.FINISHED, result=[[1, 2], [1, 2]]
    )
    assert Job.from_group_result(group_result) == expected
def test_job_from_chain_to_group(
    celery_app: Celery, celery_worker
):  # pylint: disable=unused-argument
    """A chain feeding into a group still aggregates to a FINISHED Job."""
    fan_out = group(do_B.s(), do_B.s())
    # The group's task id must be set before it is composed into the chain.
    fan_out.options["task_id"] = Job.generate_id()
    pipeline = chain(do_A.s([]), fan_out)
    group_result: GroupResult = pipeline.apply_async()
    assert group_result.get() == [[1, 2], [1, 2]]
    expected = Job(
        id=group_result.id, status=JobStatus.FINISHED, result=[[1, 2], [1, 2]]
    )
    assert Job.from_group_result(group_result) == expected
def test_group_error(
    celery_app: Celery, celery_worker
):  # pylint: disable=unused-argument
    """When every task in the group fails, the Job reports FAILURE."""
    failing_group = group(
        do_A.s(1),
        do_A.s(1),  # ints have no 'append', so both tasks raise in the worker
    )
    group_result: GroupResult = failing_group.apply_async(task_id=Job.generate_id())
    # Wait for completion without re-raising the task exceptions locally.
    group_result.get(propagate=False)
    job = Job.from_group_result(group_result)
    assert job.id == group_result.id
    assert job.status == JobStatus.FAILURE
    assert len(job.result) == 2
    # Each child result carries the exception raised inside the worker.
    for child_error in job.result:
        assert isinstance(child_error, AttributeError)
def test_group_partial_error(
    celery_app: Celery, celery_worker
):  # pylint: disable=unused-argument
    """A group with one failed and one successful task still counts as FINISHED."""
    mixed_group = group(do_A.s(1), do_A.s([]))  # first fails, second succeeds
    group_result: GroupResult = mixed_group.apply_async(task_id=Job.generate_id())
    assert len(group_result.children) == 2
    # Wait for completion; don't re-raise the failing child's exception.
    group_result.get(propagate=False)
    job = Job.from_group_result(group_result)
    assert job.id == group_result.id
    assert job.status == JobStatus.FINISHED
    assert len(job.result) == 2
    # The failed child yields its exception, the other its return value.
    assert isinstance(job.result[0], AttributeError)
    assert job.result[1] == [1]
def test_job_from_multi_chains_partial_pending(
    celery_app: Celery, celery_worker
):  # pylint: disable=unused-argument
    """While any chain is still running, the aggregated Job stays PENDING."""
    two_chains = group(
        chain(do_A.s([]), do_B.s()),
        chain(do_A.s([]), do_B.s()),
    )
    group_result: GroupResult = two_chains.apply_async(task_id=Job.generate_id())
    # Wait for only the first chain; the second remains in flight.
    assert group_result.children[0].get() == [1, 2]
    job = Job.from_group_result(group_result)
    expected = Job(
        id=group_result.id, status=JobStatus.PENDING, result=[[1, 2], None]
    )
    assert job == expected
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment