Skip to content

Instantly share code, notes, and snippets.

@TMaYaD
Created April 5, 2019 05:36
Show Gist options
  • Save TMaYaD/1c5e0898179699bdc1885ff05cbea552 to your computer and use it in GitHub Desktop.
Save TMaYaD/1c5e0898179699bdc1885ff05cbea552 to your computer and use it in GitHub Desktop.
luigi include_params_from
# Sentinel default for `only`: means "no restriction, copy every parameter".
# A dedicated object is used so that an explicitly passed empty list stays
# distinguishable from "argument not given".
_all_params_ = object()


class include_params_from(object):
    """
    Class decorator that copies parameters from other tasks onto a task.

    Every positional argument is a source task class. By default all of the
    parameters reported by each source's ``get_params()`` are copied onto the
    decorated class, except names the decorated class already defines.

    It also takes two optional keyword parameters: `only` and `omit`.
    If `only` is provided, only the parameters given in `only` will be copied.
    If `omit` is provided, parameters given in `omit` will not be copied.
    Both accept an iterable of names or a single name as a string.

    The decorated class additionally receives:

    * ``_included_params``: a dict mapping each source class to the set of
      parameter names copied from it.
    * ``clone``: a replacement ``clone`` method that warns when a parameter
      is forwarded implicitly although it was never included from the
      class being cloned to.

    Usage:

    .. code-block:: python

        class TestTask(SawyerTask):
            foo = Parameter()
            bar = Parameter()

        @include_params_from(TestTask, omit=['foo'])
        class MyFirstTask(luigi.Task):
            '''
            This will contain all params of `TestTask` except 'foo'.
            '''
            def run(self):
                print(self.bar)  # this will be defined
                print(self.foo)  # this will raise error
                # ...

        @include_params_from(TestTask, only=['foo'])
        class MySecondTask(luigi.Task):
            '''
            This will copy only 'foo' from `TestTask`.
            '''
            def run(self):
                print(self.foo)  # this will be defined
                print(self.bar)  # this will raise error.
                # ...
    """

    def __init__(self, *sources, only=_all_params_, omit=()):
        """
        Work out, per source task, which parameter names should be copied.

        :param sources: one or more task classes exposing ``get_params()``.
        :param only: names to restrict the copy to (default: no restriction).
        :param omit: names that must never be copied.
        :raises TypeError: if no source task is given.
        """
        if not sources:
            raise TypeError("sources cannot be empty")

        super().__init__()

        omitted = self._as_name_set(omit)
        # `None` here means "unrestricted"; an empty set means "copy nothing".
        wanted = None if only is _all_params_ else self._as_name_set(only)

        def params_to_include_from(task):
            names = set(p[0] for p in task.get_params())
            if wanted is not None:
                names &= wanted
            return names - omitted

        self.params_to_include = {
            task: params_to_include_from(task) for task in sources
        }

    @staticmethod
    def _as_name_set(names):
        """Return `names` as a set, treating a bare string as one name."""
        # set("foo") would yield {"f", "o"}, which is never what is meant.
        if isinstance(names, str):
            return {names}
        return set(names)

    def __call__(self, target):
        """
        Returns `target` with all the parameters which are to be
        included from other tasks.
        """
        self.copy_params(target)
        self.add_clone_method(target)
        return target

    def copy_params(self, target):
        """
        Copy the selected parameter objects from every source onto `target`
        and record what was copied in ``target._included_params``.
        """
        # Names already present on the target are never overwritten. This is
        # computed once up front, so a name offered by several sources ends
        # up bound to the parameter object of the last source.
        existing_params = set(p[0] for p in target.get_params())

        # Build a private copy of any bookkeeping the target already has
        # (possibly inherited from a decorated parent class). Copying keeps
        # the parent's record untouched, and starting from an empty dict lets
        # the decorator work on tasks that never declared the attribute.
        inherited = getattr(target, "_included_params", None) or {}
        included = {source: set(names) for source, names in inherited.items()}

        # Get all parameter objects from each of the underlying tasks
        for source, params_to_include in self.params_to_include.items():
            params_to_copy = params_to_include - existing_params
            # Store the set of params for future reference, ex., by clone.
            included[source] = included.get(source, set()) | params_to_copy
            for param_name in params_to_copy:
                setattr(target, param_name, getattr(source, param_name))

        target._included_params = included

    def add_clone_method(self, target):
        """Install the warning-aware ``clone`` method on `target`."""
        # Local import: the file has no top-level import section to extend.
        import logging

        log = logging.getLogger(__name__)

        def clone(source, cls=None, **kwargs):
            """
            This function is copied from luigi `clone` to modify for our use case.
            https://github.com/spotify/luigi/blob/b35e8bc1a7f8bc28a1291c9162b90f5628813175/luigi/task.py#L509

            Creates a new instance of a task from an existing instance where
            some of the args have changed.

            Additional feature:
            A warning is logged if a parameter is passed on to `cls`
            implicitly although it was not included from `cls` in the first
            place. A parameter object on `cls` with a truthy
            ``suppress_clone_warning`` attribute is exempt from the warning.
            """
            cls = cls or source.__class__

            # Get all common params
            params_to_copy = set(
                p[0] for p in cls.get_params() if hasattr(source, p[0])
            )
            explicit_params = set(kwargs)
            implicit_params = params_to_copy - explicit_params

            included = source._included_params.get(cls, set())
            non_included_implicit_params = set(
                name
                for name in implicit_params - included
                if not getattr(getattr(cls, name), 'suppress_clone_warning', False)
            )
            if non_included_implicit_params:
                log.warning(
                    "Parameters [%s] are not included from task: [%s] but being assigned implicitly.",
                    non_included_implicit_params,
                    cls.__name__,
                )

            # Explicit kwargs win; everything else is taken from `source`.
            kwargs.update({name: getattr(source, name) for name in implicit_params})
            return cls(**kwargs)

        target.clone = clone
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment