Last active
December 12, 2018 10:37
-
-
Save thulasi-ram/3c166f587a3df69af8aaccb4f2449db5 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
A poor man's implementation of celery like async task manager. | |
Hacked in under 2 hours. | |
Author: Thulasi | |
Usage: | |
app = Flask(__name__) | |
tasker = Tasker(app, rabbitmq_params={'hostname': 'amqp://guest:guest@localhost:5672/reseller'}) | |
or | |
tasker = Tasker(app, rabbitmq_params=app.config['RABBITMQ_CONFIG']) | |
@tasker.task | |
def long_process(key, value): | |
pass | |
long_process(key, value) # executes synchronously | |
long_process.defer(key, value) # executes asynchronously | |
""" | |
import logging | |
from typing import Dict | |
from kombu import Queue, Exchange, Connection, connections, producers, uuid | |
from kombu.mixins import ConsumerMixin | |
# Logger named after this module.
logger = logging.getLogger(name=__name__)
# Durable topic exchange used both by publish() and by the Worker's queues.
exchange = Exchange(name='tasker', type='topic', durable=True)
class Tasker:
    """Minimal Celery-like task registry backed by RabbitMQ through kombu.

    Functions registered with :meth:`task` still run synchronously when called
    directly, and gain a ``defer`` attribute that publishes the call to the
    broker instead. The ``run_tasker`` CLI command added to the Flask app
    starts a :class:`Worker` that executes deferred calls via :meth:`callback`.
    """

    # Class attribute: shared by every Tasker instance. Maps task name -> function.
    registry = {}

    def __init__(self, app, rabbitmq_params: Dict):
        """Attach to *app* and create a kombu connection from *rabbitmq_params*.

        :param app: Flask application; a ``run_tasker`` CLI command is added to it.
        :param rabbitmq_params: keyword arguments for ``kombu.Connection``. The
            mapping is copied, not modified, and ``confirm_publish`` is merged
            into any ``transport_options`` it already contains.
        """
        self.app = app
        # Copy so a caller-owned dict (e.g. app.config['RABBITMQ_CONFIG']) is not
        # mutated, and merge instead of clobbering caller-supplied transport options.
        params = dict(rabbitmq_params)
        params['transport_options'] = {
            **(params.get('transport_options') or {}),
            'confirm_publish': True,
        }
        self.connection = Connection(**params)

        @app.cli.command()
        def run_tasker():
            """Consume the task queue and execute deferred tasks."""
            with self.connection as conn:
                worker = Worker(connection=conn, callback=self.callback)
                worker.run()

    def task(self, func=None, unique_name=''):
        """Register *func* as a task and attach a ``defer`` method to it.

        Usable as ``@tasker.task``, ``@tasker.task(unique_name='...')`` or
        ``tasker.task(func, unique_name='...')``. The task is registered under
        *unique_name* when given, otherwise under ``func.__qualname__``.

        :returns: *func* itself (still callable synchronously), or, when *func*
            is omitted, a decorator that performs the registration.
        :raises RuntimeError: if a different function is already registered
            under the same name.
        """
        if func is None:
            # Called as @tasker.task(unique_name=...): hand back the real decorator.
            return lambda f: self.task(f, unique_name=unique_name)

        task_name = unique_name or func.__qualname__

        def defer(*args, **kwargs):
            """Publish this call to the broker and return its message id."""
            data = {
                'task_name': task_name,
                'args': args,
                'kwargs': kwargs,
            }
            # This routing key matches the 'task.#' binding of the Worker's 'tasks' queue.
            return publish(self.connection, routing_key='task.#', data=data)

        self._register(func, task_name)
        func.defer = defer
        return func

    def callback(self, body):
        """Execute the task described by a message *body* produced by ``defer``.

        :param body: dict with ``task_name``, ``args`` and ``kwargs`` keys.
        :returns: whatever the task returns (the worker ignores it).
        :raises KeyError: if no task is registered under ``task_name``.
        """
        task_name = body['task_name']
        try:
            func = self.registry[task_name]
        except KeyError:
            raise KeyError(f'No task registered under {task_name!r}') from None
        return func(*body.get('args', ()), **body.get('kwargs', {}))

    def _register(self, func, task_name):
        """Record *func* under *task_name*, refusing to replace a different task."""
        if task_name in self.registry and self.registry[task_name] != func:
            raise RuntimeError('Duplicate task received with same name. Use @task(unique_name=...)')
        self.registry[task_name] = func
def publish(connection, routing_key, data, unique_id=None):
    """Publish *data* to the ``tasker`` exchange and return its message id.

    :param connection: kombu ``Connection`` whose pool is used for publishing.
    :param routing_key: routing key for the topic exchange, e.g. ``'task.#'``
        or ``'dead.task'``.
    :param data: message body (a dict). It is shallow-copied, and the copy
        gets an empty ``errors`` list if it has none, so the worker can record
        failures on it.
    :param unique_id: message id to use; a fresh ``uuid()`` when falsy.
    :returns: the message id that was used.
    """
    unique_id = unique_id or uuid()
    # Shallow copy so the caller's dict is not modified behind its back.
    payload = dict(data)
    payload.setdefault('errors', [])
    # Wait up to 300 s for a pooled connection and 30 s for a pooled producer.
    with connections[connection].acquire(block=True, timeout=300) as conn:
        with producers[conn].acquire(block=True, timeout=30) as producer:
            logger.info('Publishing message %s with %s', unique_id, payload)
            # NOTE(review): only the exchange is declared here. RabbitMQ discards
            # messages that match no bound queue (unless ``mandatory`` is set), so
            # tasks deferred before any Worker has declared the 'tasks' queue may
            # be lost. Confirm this is acceptable.
            producer.publish(
                exchange=exchange,
                body=payload,
                routing_key=routing_key,
                declare=[exchange],
                message_id=unique_id,
            )
    return unique_id
class Worker(ConsumerMixin):
    """Consume deferred tasks and dead-letter the ones that fail.

    Messages arrive on the ``tasks`` queue (bound to ``task.#``) and each body
    is passed to *callback*. If the callback raises, the error's repr is
    appended to the body's ``errors`` list and the body is republished with
    routing key ``dead.task``, which routes it to the ``dead-tasks`` queue.
    Every message is acknowledged after handling, whether or not the task
    succeeded.
    """

    queue = Queue('tasks', exchange, 'task.#')
    dead_queue = Queue('dead-tasks', exchange, 'dead.task.#')

    def __init__(self, connection, callback):
        """
        :param connection: kombu connection to consume from (read by ConsumerMixin).
        :param callback: callable invoked with each decoded message body.
        """
        self.connection = connection
        self.callback = callback
        # Declare the dead-letter queue up front so failed tasks are routable even
        # though nothing consumes it here. Declare via a bound *copy* on a
        # short-lived channel, instead of binding the shared class attribute to a
        # channel that is never closed.
        channel = self.connection.channel()
        try:
            self.dead_queue(channel).declare()
        finally:
            channel.close()

    def get_consumers(self, Consumer, channel):
        # NOTE(review): no ``accept=`` list is given, so the consumer may accept any
        # registered content type (including pickle) from the broker. Consider
        # ``accept=['json']`` if the broker is reachable by untrusted publishers.
        return [Consumer(queues=[self.queue], callbacks=[self.on_task])]

    def on_task(self, body, message):
        """Run one task message, dead-letter it on failure, then ack it."""
        unique_id = message.properties.get('message_id')
        if not unique_id:
            unique_id = uuid()
            message.properties['message_id'] = unique_id
        logger.info('Got message with task_id: %s', unique_id)
        try:
            self.callback(body)
        except Exception as e:
            logger.exception('Task %s failed', unique_id)
            # Bodies not produced by publish() may lack the errors list.
            body.setdefault('errors', []).append(repr(e))
            publish(
                self.connection,
                routing_key='dead.task',
                data=body,
                unique_id=unique_id,
            )
        finally:
            # NOTE(review): the message is acked even if dead-letter publishing
            # above raised, in which case the task is lost.
            message.ack()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment