Skip to content

Instantly share code, notes, and snippets.

@Xion
Created November 4, 2015 06:39
Show Gist options
  • Star 7 You must be signed in to star a gist
  • Fork 2 You must be signed in to fork a gist
  • Save Xion/46d739b980af14a0d3ea to your computer and use it in GitHub Desktop.
Save Xion/46d739b980af14a0d3ea to your computer and use it in GitHub Desktop.
Base class for Celery tasks that run inside a Flask request context
from celery import Task
from flask import has_request_context, make_response, request
from myapp import app
# Public API of this module: only the task base class is exported.
__all__ = ["RequestContextTask"]
class RequestContextTask(Task):
    """Base class for tasks that originate from Flask request handlers
    and carry over most of the request context data.

    This has the advantage of being able to access all the usual information
    that the HTTP request has and use it within the task. Potential
    use cases include e.g. formatting URLs for external use in emails sent
    by tasks.

    When a task is sent while a Flask request context is active, the
    request's path, base URL, method, headers and query string are added to
    the task's keyword arguments under :attr:`CONTEXT_ARG_NAME`. When the
    task runs, that entry is removed again. If no request context is active
    at that point, the task body executes inside
    ``app.test_request_context(**context)``. Pass
    ``with_request_context=False`` to :meth:`apply_async`, :meth:`apply` or
    :meth:`retry` to opt out.
    """

    abstract = True

    #: Name of the additional parameter passed to tasks
    #: that contains information about the original Flask request context.
    CONTEXT_ARG_NAME = '_flask_request_context'

    def __call__(self, *args, **kwargs):
        """Execute task code with given arguments.

        The serialized request context (if any) is popped from *kwargs* so
        the task body never sees it. If a context was passed and no request
        context is active yet, the body runs inside a reconstructed test
        request context.
        """
        context = kwargs.pop(self.CONTEXT_ARG_NAME, None)

        def call():
            # Two-argument super() because zero-argument super() is not
            # available inside a nested function.
            return super(RequestContextTask, self).__call__(*args, **kwargs)

        if context is None or has_request_context():
            return call()

        with app.test_request_context(**context):
            result = call()
            # process a fake "Response" so that
            # ``@after_request`` hooks are executed
            app.process_response(make_response(result or ''))
        return result

    def apply_async(self, args=None, kwargs=None, **rest):
        """Send the task, attaching the current request context if any.

        ``with_request_context=False`` (default ``True``) disables it.
        """
        if rest.pop('with_request_context', True):
            kwargs = self._include_request_context(kwargs)
        return super(RequestContextTask, self).apply_async(args, kwargs, **rest)

    def apply(self, args=None, kwargs=None, **rest):
        """Run the task locally, attaching the current request context if any.

        ``with_request_context=False`` (default ``True``) disables it.
        """
        if rest.pop('with_request_context', True):
            kwargs = self._include_request_context(kwargs)
        return super(RequestContextTask, self).apply(args, kwargs, **rest)

    def retry(self, args=None, kwargs=None, **rest):
        """Retry the task, attaching the current request context if any.

        ``with_request_context=False`` (default ``True``) disables it.
        """
        if rest.pop('with_request_context', True):
            if kwargs is None:
                # Celery's retry() reuses the current task's kwargs when none
                # are given; do the same here so adding the context does not
                # replace the original arguments with the context alone.
                kwargs = self.request.kwargs
            kwargs = self._include_request_context(kwargs)
        return super(RequestContextTask, self).retry(args, kwargs, **rest)

    def _include_request_context(self, kwargs):
        """Return *kwargs* with the current Flask request context added
        as an additional argument to the task.

        Returns *kwargs* unchanged when there is no active request context.
        Otherwise returns a new dict (the caller's dict is not mutated);
        ``None`` is treated as an empty dict.
        """
        if not has_request_context():
            return kwargs

        # keys correspond to arguments of :meth:`Flask.test_request_context`
        context = {
            'path': request.path,
            'base_url': request.url_root,
            'method': request.method,
            'headers': dict(request.headers),
        }
        _, has_query, query = request.url.partition('?')
        if has_query:
            context['query_string'] = query

        new_kwargs = dict(kwargs or {})
        new_kwargs[self.CONTEXT_ARG_NAME] = context
        return new_kwargs
@systemime
Copy link

systemime commented Aug 17, 2022

This approach no longer works with newer versions of Celery.

Based on your example, I implemented the following solution:

import json
import os
from datetime import timedelta
from typing import Dict, Union
from uuid import uuid1

from celery import Celery, Task
from flask import g, has_app_context, has_request_context, request


class ContextTask(Task):
    """Celery base task that forwards Flask request data and runs in an app context.

    When a task is sent from inside a Flask request, a small snapshot of that
    request (method, headers and the id of ``g.flask_httpauth_user``) is
    passed as custom message ``headers`` and also stored in the ``redis``
    Flask extension under ``TASK_DUMP_REQUEST_CONTEXT::<task_id>``. When the
    task runs, its body executes inside the Flask application context of
    ``self.app.flask_app``.
    """

    #: Optional expiry for the Redis copy of the request snapshot
    #: (``timedelta`` or seconds). ``None`` keeps the key forever.
    request_context_ttl: Union[timedelta, int, None] = None

    def dump_current_request_context(self, task_id) -> Union[Dict, None]:
        """Snapshot the active Flask request for the task *task_id*.

        Returns ``None`` outside a request context. Otherwise returns a
        JSON-serializable dict with the request method, headers and user id
        (``None`` when there is no authenticated user), after also writing
        it to Redis as a spare copy.
        """
        if not has_request_context():
            return None

        user = getattr(g, "flask_httpauth_user", None)
        context_info = {
            "method": request.method,
            "headers": dict(request.headers),
            "user_id": user.id if user else None,
        }

        redis = self.app.flask_app.extensions["redis"]
        key = f"TASK_DUMP_REQUEST_CONTEXT::{task_id}"
        payload = json.dumps(context_info)
        if self.request_context_ttl is None:
            redis.set(key, payload)
        else:
            redis.set(key, payload, ex=self.request_context_ttl)

        return context_info

    def _attach_request_context(self, rest, fallback_task_id=None):
        """Fill in ``task_id`` and request-snapshot ``headers`` in *rest*.

        A caller-supplied ``task_id`` (e.g. from a frozen signature or chain)
        is kept; otherwise *fallback_task_id* or a fresh uuid1 is used.
        Caller-supplied headers are merged over the snapshot instead of being
        discarded. When there is no request to snapshot, existing headers
        are left untouched. Mutates and returns *rest*.
        """
        task_id = rest.get("task_id") or fallback_task_id or uuid1().hex
        rest["task_id"] = task_id
        context_info = self.dump_current_request_context(task_id)
        if context_info is not None:
            rest["headers"] = {**context_info, **(rest.get("headers") or {})}
        return rest

    def apply_async(self, args=None, kwargs=None, **rest):
        """Send the task; custom request data travels as message headers."""
        # A new message gets a new id. ``self.request.id`` is deliberately not
        # reused: when a task schedules itself, that is the id of the task
        # that is currently running.
        self._attach_request_context(rest)
        return super().apply_async(args, kwargs, **rest)

    def apply(self, args=None, kwargs=None, **rest):
        """Run the task locally with the same request data as apply_async."""
        self._attach_request_context(rest)
        return super().apply(args, kwargs, **rest)

    def retry(self, args=None, kwargs=None, **rest):
        """Retry the current task, keeping its id and original headers."""
        # A retry re-sends the *current* task, so its id is the fallback.
        self._attach_request_context(rest, self.request.id)
        return super().retry(args, kwargs, **rest)

    def __call__(self, *args, **kwargs):
        """Run the task body inside the Flask application context.

        Reuses an already-active app context if there is one; otherwise
        pushes ``self.app.flask_app.app_context()`` for the duration of the
        call. Sets ``FLASK_CONTEXT_IN_CELERY=true`` in ``os.environ``, which
        is process-wide, before running the body.
        """
        if has_app_context():
            os.environ["FLASK_CONTEXT_IN_CELERY"] = "true"
            return Task.__call__(self, *args, **kwargs)
        with self.app.flask_app.app_context():
            os.environ["FLASK_CONTEXT_IN_CELERY"] = "true"
            return Task.__call__(self, *args, **kwargs)

def create_celery_app(flask_app, config=None):
    """Create a Celery app bound to *flask_app* that uses ``ContextTask``.

    :param flask_app: the Flask application. It is stored as
        ``celery.flask_app`` so tasks can reach its extensions and app context.
    :param config: optional Celery configuration (mapping, object or module
        name), applied with ``Celery.config_from_object``. Its
        ``celery_name`` attribute, if present, names the app when
        ``flask_app.import_name`` is empty.
    :returns: the configured Celery instance.
    """
    # The fallback name comes from *config*; a missing attribute (or no
    # config at all) falls through to "default" instead of raising.
    name = flask_app.import_name or getattr(config, "celery_name", None) or "default"
    celery = Celery(name)
    if config is not None:
        celery.config_from_object(config)
    celery.autodiscover_tasks()

    celery.Task = ContextTask

    if not hasattr(celery, "flask_app"):
        celery.flask_app = flask_app

    return celery

@celery_app.task(bind=True, name="task_debug")
def task_debug(self, *args, **kwargs):
    """Debug task: log the message headers, the task id and the arguments.

    ``celery_app`` and ``task_logger`` are assumed to be defined elsewhere in
    the module (TODO: confirm).
    """
    # Here you can get custom data (passed as message headers, e.g. by
    # ContextTask when the app uses it as its base task).
    task_logger.info(self.request.headers)
    task_logger.info("task id: %s, args: %r, kwargs: %r", self.request.id, args, kwargs)

    # NOTE(review): the original returned the undefined name ``ok``; a plain
    # marker string is returned instead (TODO: confirm the intended value).
    return "ok"

Thanks for the idea.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment