Created
April 5, 2019 05:36
-
-
Save TMaYaD/1c5e0898179699bdc1885ff05cbea552 to your computer and use it in GitHub Desktop.
luigi include_params_from
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Sentinel meaning "no `only` filter was given"; lets callers pass any real
# iterable (including an empty one) as `only`.
_all_params_ = object()


class include_params_from(object):
    """
    Class decorator that copies parameters from other tasks onto a task.

    Every positional argument is a source task class whose parameters (as
    reported by its ``get_params()``) are copied onto the decorated class.
    Parameters the decorated class already defines are left untouched.

    It also takes two optional keyword parameters: `only` and `omit`.
    If `only` is provided, only the parameters given in `only` will be copied.
    If `omit` is provided, parameters given in `omit` will not be copied.

    The decorated class additionally receives:

    * ``_included_params``: a dict mapping each source task class to the set
      of parameter names that were copied from it.
    * ``clone``: a replacement for luigi's ``clone`` that warns when a
      parameter is passed along implicitly without having been included
      from the task being cloned into.

    Usage:

    .. code-block:: python

        class TestTask(SawyerTask):
            foo = Parameter()
            bar = Parameter()

        @include_params_from(TestTask, omit=['foo'])
        class MyFirstTask(luigi.Task):
            '''
            This will contain all params of `TestTask` except 'foo'.
            '''
            def run(self):
                print(self.bar)  # this will be defined
                print(self.foo)  # this will raise error
                # ...

        @include_params_from(TestTask, only=['foo'])
        class MySecondTask(luigi.Task):
            '''
            This will copy only 'foo' from `TestTask`.
            '''
            def run(self):
                print(self.foo)  # this will be defined
                print(self.bar)  # this will raise error.
                # ...
    """

    def __init__(self, *sources, only=_all_params_, omit=()):
        """
        Work out which parameter names to copy from each source task.

        :param sources: one or more task classes exposing ``get_params()``.
        :param only: iterable of parameter names to restrict the copy to;
            when left at its default every parameter is eligible.
        :param omit: iterable of parameter names that must not be copied.
        :raises TypeError: if no source task is given.
        """
        if not sources:
            raise TypeError("sources cannot be empty")

        omitted = set(omit)

        def params_to_include_from(task):
            task_params = set(p[0] for p in task.get_params())
            if only is not _all_params_:
                task_params = task_params & set(only)
            return task_params - omitted

        super().__init__()
        self.params_to_include = {task: params_to_include_from(task) for task in sources}

    def __call__(self, target):
        """
        Returns `target` with all the parameters which are to be
        included from other tasks.
        """
        self.copy_params(target)
        self.add_clone_method(target)
        return target

    def copy_params(self, target):
        """
        Copy the selected parameter objects from every source onto `target`
        and record what was copied in ``target._included_params``.
        """
        # Give `target` its own registry. Without this, a class with no
        # registry fails outright, and a registry inherited from a base class
        # would be mutated in place and shared by every other subclass.
        if '_included_params' not in vars(target):
            inherited = getattr(target, '_included_params', {})
            target._included_params = {
                source: set(names) for source, names in inherited.items()
            }

        existing_params = set(p[0] for p in target.get_params())

        # Get all parameter objects from each of the underlying tasks
        for source, params_to_include in self.params_to_include.items():
            params_to_copy = params_to_include - existing_params

            # Store the set of params for future reference, ex., by clone.
            pre_existing_set = target._included_params.get(source, set())
            target._included_params[source] = pre_existing_set | params_to_copy

            for param_name in params_to_copy:
                setattr(target, param_name, getattr(source, param_name))

    def add_clone_method(self, target):
        """Attach the warning-aware ``clone`` method to `target`."""
        # Imported here so this block does not depend on a module-level
        # `logger` being defined elsewhere in the file.
        import logging

        log = logging.getLogger(__name__)

        def clone(source, cls=None, **kwargs):
            """
            This function is copied from luigi `clone` to modify for our use case.
            https://github.com/spotify/luigi/blob/b35e8bc1a7f8bc28a1291c9162b90f5628813175/luigi/task.py#L509

            Creates a new instance of a task from an existing instance where some of the args have changed.

            Additional feature:
            Warnings are shown if a parameter(s) is passed in `cls` which was not
            included from `cls` at the first place.
            """
            cls = cls or source.__class__

            # Get all common params
            params_to_copy = set(p[0] for p in cls.get_params() if hasattr(source, p[0]))
            explicit_params = set(kwargs.keys())
            implicit_params = params_to_copy - explicit_params

            non_included_implicit_params = implicit_params - source._included_params.get(cls, set())
            # A parameter object can opt out of the warning by carrying a
            # truthy `suppress_clone_warning` attribute.
            non_included_implicit_params = set(
                p for p in non_included_implicit_params
                if not getattr(getattr(cls, p), 'suppress_clone_warning', False)
            )
            if non_included_implicit_params:
                log.warning(
                    "Parameters [%s] are not included from task: [%s] but being assigned implicitly.",
                    non_included_implicit_params,
                    cls.__name__,
                )

            kwargs.update({param: getattr(source, param) for param in implicit_params})
            return cls(**kwargs)

        target.clone = clone
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment