Skip to content

Instantly share code, notes, and snippets.

@liverpoolpjy
Last active March 27, 2024 18:45
Show Gist options
  • Save liverpoolpjy/49cb0e0948053606dc40e350062a97d2 to your computer and use it in GitHub Desktop.
Save liverpoolpjy/49cb0e0948053606dc40e350062a97d2 to your computer and use it in GitHub Desktop.
celery_with_request_context.py
from flask_celery import Celery
from flask import has_request_context, make_response, request
class CeleryWithContext(Celery):
    """Celery extension whose tasks re-create the calling Flask request context.

    When a task is queued or applied from inside a Flask request, a
    description of that request is attached to the task kwargs.  When the
    task later executes outside of any request, that description is used to
    build a ``test_request_context`` around the task body, so code relying on
    ``flask.request`` keeps working.
    """

    def init_app(self, app):
        """Initialise the extension for ``app`` and install ``ContextTask``.

        :param app: the Flask application whose request context is
            re-created around task execution.
        """
        # The original referenced an undefined ``MyCelery`` here (NameError).
        super(CeleryWithContext, self).init_app(app)
        task_base = self.Task

        class ContextTask(task_base):
            """Task base class that carries Flask request context to workers."""

            # Marks this as a base class rather than a concrete task.
            abstract = True

            #: Name of the additional parameter passed to tasks
            #: that contains information about the original Flask request context.
            CONTEXT_ARG_NAME = '_flask_request_context'

            def __call__(self, *args, **kwargs):
                """Execute task code with given arguments.

                Strips the injected context kwarg before delegating to the
                parent ``__call__``.  If a context was supplied and no request
                context is active, the task runs inside
                ``app.test_request_context(**context)``.
                """
                context = kwargs.pop(self.CONTEXT_ARG_NAME, None)
                parent_call = super(ContextTask, self).__call__
                if context is None or has_request_context():
                    return parent_call(*args, **kwargs)
                with app.test_request_context(**context):
                    result = parent_call(*args, **kwargs)
                    # Process an empty fake response so ``@after_request``
                    # hooks run.  The task's return value is deliberately not
                    # turned into a response: arbitrary values (ints, custom
                    # objects, ...) make ``make_response`` raise and would fail
                    # a task whose body already succeeded.
                    app.process_response(app.response_class())
                    return result

            def apply_async(self, args=None, kwargs=None, **rest):
                """Queue the task, attaching the current request context if any."""
                kwargs = self._include_request_context(kwargs)
                return super(ContextTask, self).apply_async(args, kwargs, **rest)

            def apply(self, args=None, kwargs=None, **rest):
                """Run the task locally, attaching the current request context if any."""
                kwargs = self._include_request_context(kwargs)
                return super(ContextTask, self).apply(args, kwargs, **rest)

            def retry(self, args=None, kwargs=None, **rest):
                """Retry the task, attaching the current request context if any.

                When ``kwargs`` is None it is passed through unchanged, so the
                parent ``retry`` applies its own default (presumably the
                original request kwargs, which would still carry the context
                key -- confirm against the Celery version in use).
                """
                if kwargs is not None:
                    kwargs = self._include_request_context(kwargs)
                return super(ContextTask, self).retry(args, kwargs, **rest)

            def _include_request_context(self, kwargs):
                """Return task kwargs extended with the current request context.

                The caller's mapping is never mutated: a new dict is built
                (``None`` is treated as an empty mapping).  Outside a request
                context ``kwargs`` is returned unchanged.

                :param kwargs: keyword arguments for the task, or None.
                :returns: the kwargs to pass on to Celery.
                """
                if not has_request_context():
                    return kwargs
                # keys correspond to arguments of :meth:`Flask.test_request_context`
                context = {
                    'path': request.path,
                    'base_url': request.url_root,
                    'method': request.method,
                    'headers': dict(request.headers),
                    # Plain dict of lists keeps repeated form keys and is
                    # serializer-friendly.
                    'data': request.form.to_dict(flat=False),
                }
                if request.query_string:
                    # Use the raw query string rather than slicing the
                    # reconstructed URL; WSGI environ strings are latin-1
                    # (PEP 3333), so this decode is lossless.
                    context['query_string'] = request.query_string.decode('latin-1')
                new_kwargs = dict(kwargs or {})
                new_kwargs[self.CONTEXT_ARG_NAME] = context
                return new_kwargs

        self.Task = ContextTask


celery = CeleryWithContext()
@jackjiali
Copy link

兄弟,管用否?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment