Skip to content

Instantly share code, notes, and snippets.

@pveierland
Forked from alexgarel/celery_test_util.py
Created September 30, 2022 09:28
Show Gist options
  • Save pveierland/029892f3d796afea4f429d1d762f7d8e to your computer and use it in GitHub Desktop.
Save pveierland/029892f3d796afea4f429d1d762f7d8e to your computer and use it in GitHub Desktop.
Helper for tests with Celery: runs tasks in the main thread while keeping control over when they run (an alternative to eager mode)
from functools import partial
from celery.app.task import Task
from celery.app.utils import find_app
class CeleryTestTask(object):
    """A context manager that patches tasks so delayed calls are queued
    and eventually run in the current thread. This is for tests.

    You may access ``tasks`` and ``done_tasks`` in your test to verify
    which tasks were triggered / run.

    This class is useful when using e.g. transaction hooks and atomic
    requests.

    Simple usage, to run tasks sequentially at the end of a request in a
    functional test::

        with CeleryTestTask(apps="my_app") as celery_tasks:
            self.client.post("/my-url", {"foo": "bar"})
        assert celery_tasks.done_tasks
    """

    # Marks task classes that only inherited ``apply_async``: on exit the
    # patched attribute is deleted instead of being overwritten, so the
    # class falls back to the inherited implementation.
    _NOT_DEFINED = object()

    def __init__(self, apps, run=True):
        """
        :param list apps: name of celery apps to catch,
            you can use a string for a single app
        :param run: run queued tasks on exit.
            Note that if tasks launch new tasks they will be queued
            and played.
        """
        if isinstance(apps, str):
            apps = [apps]
        self.apps = apps
        self.run = run
        # (task, args, kwargs) tuples waiting to be run, in call order.
        self.tasks = []
        # (task, args, kwargs) tuples that were run on exit.
        self.done_tasks = []
        # task class -> the ``apply_async`` entry of its own ``__dict__``
        # before patching (or ``_NOT_DEFINED`` when it had none).
        self.orig_apply_async = {}

    def queue_task(self, task_class, args, kwargs):
        """Record a call so it can be run when the context exits."""
        self.tasks.append((task_class, args, kwargs))

    def __enter__(self):
        """Patch ``apply_async`` on every task class of the configured apps.

        :return: this instance, so ``with ... as ctx`` gives access to
            ``tasks`` and ``done_tasks``.
        """
        def apply_async_patch(task, args=None, kwargs=None, **options):
            # Queue the call instead of sending it; ``options`` are ignored.
            self.queue_task(task, args, kwargs)

        try:
            for app in self.apps:
                for task in find_app(app).tasks.values():
                    task_class = task.__class__
                    # setdefault keeps the first (unpatched) value when
                    # several tasks share one class.
                    self.orig_apply_async.setdefault(
                        task_class,
                        vars(task_class).get(
                            "apply_async", self._NOT_DEFINED
                        ),
                    )
                    task_class.apply_async = partial(apply_async_patch, task)
        except Exception:
            # __exit__ is not called when __enter__ fails, so undo the
            # partial patching here.
            self._unpatch()
            raise
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Run the queued tasks (if enabled and no error occurred), then
        restore the patched classes.

        Tasks are only run when the ``with`` body finished without an
        exception and ``run`` is true. Unpatching always happens, even if
        a queued task raises. Exceptions are never suppressed.
        """
        try:
            if exc_type is None and self.run:
                # Tasks queued by running tasks are appended to the same
                # list, so they are picked up by this loop too.
                while self.tasks:
                    task, args, kwargs = self.tasks.pop(0)
                    # apply_async defaults both to None; treat that as
                    # "no arguments" rather than unpacking None.
                    task(*(args or ()), **(kwargs or {}))
                    self.done_tasks.append((task, args, kwargs))
        finally:
            self._unpatch()

    def _unpatch(self):
        """Put every patched task class back in its original state."""
        for task_class, original in self.orig_apply_async.items():
            if original is self._NOT_DEFINED:
                if "apply_async" in vars(task_class):
                    delattr(task_class, "apply_async")
            else:
                task_class.apply_async = original
        self.orig_apply_async.clear()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment