Skip to content

Instantly share code, notes, and snippets.

@omegaml

omegaml/omx_chained.py

Last active Dec 22, 2020
Embed
What would you like to do?
omega|ml plugin to run chained tasks
"""
omegaml plugin to chain runtime tasks
Usage:
# this will chain fit and predict, i.e. predict will run only if fit succeeds
with om.runtime.chain() as crt:
    crt.model('regmodelx').fit('sample[y]', 'sample[x]')
    crt.model('regmodelx').predict([5], rName='foox')
    result = crt.run()
# some time later
print(result.get())
Installation:
!pip install -q getgist
!rm -f *omx_chained.py && getgist -y omegaml omx_chained.py
import omx_chained
"""
import omegaml as om
import types
from contextlib import contextmanager
class TaskChain:
    """Accumulate runtime tasks and submit them as a single celery chain.

    Tasks are appended with :meth:`add`. The next :meth:`delay` or
    :meth:`apply_async` call does not submit anything; it replaces the most
    recently added task with an immutable celery signature. :meth:`run` then
    submits all collected signatures as one celery chain.
    """

    def __init__(self):
        # collected tasks; each entry becomes a celery signature once
        # apply_async() has been called for it
        self.tasks = []

    def add(self, task):
        """Append *task* to the end of the chain."""
        self.tasks.append(task)

    def delay(self, *args, **kwargs):
        """Shortcut for ``apply_async(args=args, kwargs=kwargs)``; returns self."""
        return self.apply_async(args=args, kwargs=kwargs)

    def apply_async(self, args=None, kwargs=None, **celery_kwargs):
        """Convert the most recently added task into an immutable signature.

        Args:
            args: positional task arguments; ``None`` means no arguments.
            kwargs: keyword task arguments; ``None`` means none. The mapping
                is copied, so the caller's dict is left untouched.
            **celery_kwargs: celery options forwarded to ``signature()``.

        Returns:
            self, so further calls can be chained.

        Raises:
            IndexError: if no task has been added yet.
        """
        args = () if args is None else args
        # _apply_kwargs()'s return value is unused, so it is expected to
        # update these dicts in place; copy to avoid mutating the caller's
        # mapping and to never hand it None
        kwargs = {} if kwargs is None else dict(kwargs)
        task = self.tasks[-1]
        task._apply_kwargs(kwargs, celery_kwargs)
        # immutable means results are not passed on from task to task
        sig = task.task.signature(args=args, kwargs=kwargs, **celery_kwargs, immutable=True)
        self.tasks[-1] = sig
        return self

    def run(self):
        """Submit all collected tasks as one celery chain.

        Returns:
            Whatever celery's ``chain(...).apply_async()`` returns.
        """
        from celery import chain
        chained = chain(*self.tasks)
        return chained.apply_async()
@contextmanager
def chain(self):
    """Collect the runtime tasks requested inside the block into a TaskChain.

    While the context is active, ``self.task`` is replaced by a wrapper. The
    wrapper calls the original ``task()``, appends the result to a
    :class:`TaskChain`, and returns that chain. ``self.run`` is bound to
    :meth:`TaskChain.run`, so the collected tasks can be submitted as one
    celery chain. On exit, the runtime's ``task`` and ``run`` attributes are
    restored to exactly the state they had before the block was entered.

    Yields:
        The runtime itself (``self``), patched for chaining.
    """
    task_chain = TaskChain()
    missing = object()
    # the patch is applied as instance attributes; remember any pre-existing
    # instance-level values so the prior state can be restored precisely
    # (this also keeps nested chain() blocks from clobbering each other)
    saved = {name: vars(self).get(name, missing) for name in ('task', 'run')}
    orig_task = self.task

    def chaining_task(*args, **kwargs):
        # record the task and hand back the chain, so the caller's subsequent
        # .delay()/.apply_async() turns it into a signature instead of a run
        task_chain.add(orig_task(*args, **kwargs))
        return task_chain

    task_chain.runtime = self
    self.task = chaining_task
    self.run = task_chain.run
    try:
        yield self
    finally:
        for name, value in saved.items():
            if value is missing:
                # no instance attribute existed before: remove ours so the
                # class-level attribute (if any) becomes visible again
                vars(self).pop(name, None)
            else:
                setattr(self, name, value)
# attach: install chain() as a method on OmegaRuntime, so any runtime
# instance supports `with runtime.chain() as crt: ...`
from omegaml.runtimes.runtime import OmegaRuntime
setattr(OmegaRuntime, 'chain', chain)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment