Last active
December 22, 2020 15:22
-
-
Save omegaml/c5a8cb3b08fcc38d7dca34ed272f7528 to your computer and use it in GitHub Desktop.
omega|ml plugin to run chained tasks
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
omegaml plugin to chain runtime tasks | |
Usage: | |
# this will chain fit and predict, i.e. predict will run only after fit succeeds
with om.runtime.chain() as crt: | |
crt.model('regmodelx').fit('sample[y]', 'sample[x]') | |
crt.model('regmodelx').predict([5], rName='foox') | |
result = crt.run() | |
# sometime later | |
print(result.get()) | |
Installation: | |
!pip install -q getgist | |
!rm -f *omx_chained.py && getgist -y omegaml omx_chained.py | |
import omx_chained | |
""" | |
import omegaml as om | |
import types | |
from contextlib import contextmanager | |
class TaskChain:
    """Collect runtime tasks and submit them as one celery chain.

    Entries are appended with :meth:`add`; calling :meth:`delay` or
    :meth:`apply_async` converts the most recently added entry into an
    immutable celery signature. :meth:`run` links all entries, in the
    order they were added, and submits the chain.
    """

    def __init__(self):
        # entries in execution order; each entry starts out as the runtime's
        # task wrapper and is replaced by a signature in apply_async()
        self.tasks = []

    def add(self, task):
        """Append *task* as the last entry of the chain."""
        self.tasks.append(task)

    def delay(self, *args, **kwargs):
        """Shortcut for ``apply_async(args=args, kwargs=kwargs)``.

        Returns:
            this TaskChain
        """
        return self.apply_async(args=args, kwargs=kwargs)

    def apply_async(self, args=None, kwargs=None, **celery_kwargs):
        """Turn the most recently added task into an immutable signature.

        Nothing is submitted here; the signature only replaces the last
        entry in ``self.tasks`` so that :meth:`run` can chain it later.

        Args:
            args: positional arguments for the task, defaults to ``()``
            kwargs: keyword arguments for the task, defaults to ``{}``
            **celery_kwargs: additional options passed to ``signature()``

        Returns:
            this TaskChain

        Raises:
            IndexError: if no task has been added yet
        """
        # normalize the None defaults: _apply_kwargs is handed these objects
        # and would fail on None if it updates them in place
        args = () if args is None else args
        kwargs = {} if kwargs is None else kwargs
        task = self.tasks[-1]
        # NOTE(review): presumably merges the runtime's task/routing settings
        # into both dicts in place -- confirm against omegaml's task wrapper
        task._apply_kwargs(kwargs, celery_kwargs)
        # immutable means results are not passed on from task to task
        sig = task.task.signature(args=args, kwargs=kwargs, **celery_kwargs, immutable=True)
        self.tasks[-1] = sig
        return self

    def run(self):
        """Chain all collected entries and submit them.

        Returns:
            whatever ``apply_async()`` of the celery chain returns
        """
        # imported lazily, celery is only needed once the chain is submitted
        from celery import chain
        chained = chain(*self.tasks)
        return chained.apply_async()
@contextmanager
def chain(self):
    """Context manager that records the runtime's tasks in a TaskChain.

    While the context is active, ``self.task`` is replaced by a wrapper
    that creates the task as usual, adds it to a fresh TaskChain and
    returns that chain instead of the task. ``self.run`` is bound to the
    chain's ``run()`` so the collected tasks can be submitted.

    On exit, ``self.task`` is reset to the original and ``self.run`` is
    restored to what it was before entering. If the runtime had no
    ``run`` attribute, the attribute is removed again rather than left
    behind as ``None``.

    Yields:
        the runtime itself (``self``), with ``task`` and ``run`` patched
    """
    task_chain = TaskChain()
    _orig_task = self.task
    # sentinel distinguishes "no run attribute" from "run is None"
    _missing = object()
    _orig_run = getattr(self, 'run', _missing)

    def chaining_task(*args, **kwargs):
        # create the task as usual, but hand back the chain so that the
        # subsequent .delay()/.apply_async() call is recorded, not executed
        task = _orig_task(*args, **kwargs)
        task_chain.add(task)
        return task_chain

    self.task = chaining_task
    task_chain.runtime = self
    self.run = task_chain.run
    try:
        yield self
    finally:
        self.task = _orig_task
        if _orig_run is _missing:
            # the runtime had no run before, so do not leave one behind
            try:
                del self.run
            except AttributeError:
                pass
        else:
            self.run = _orig_run
# attach: register chain() as a method on OmegaRuntime, so that it can be
# used as shown in the module docstring: `with om.runtime.chain() as crt: ...`
from omegaml.runtimes.runtime import OmegaRuntime
OmegaRuntime.chain = chain
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment