Last active
August 29, 2015 14:27
-
-
Save hcarvalhoalves/cf8641bca3ec086c1e0a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import collections | |
## same boilerplate as in common_models
class Stage(object):
    """Base class for one unit of work in a lazily evaluated graph.

    Constructor keyword arguments are stored untouched in ``constants``.
    ``a | b`` builds a ``Pipeline`` and ``a + b`` builds a ``Combinator``;
    a stage's ``do`` is only invoked when the graph is evaluated by ``run``.
    """

    def __init__(self, **constants):
        self.constants = constants

    def __or__(self, other):
        """Chain ``other`` after this stage."""
        return Pipeline(self, other)

    def __add__(self, other):
        """Evaluate ``other`` side by side with this stage."""
        return Combinator(self, other)

    def do(self, *_args, **_kwargs):
        """Perform the stage's work; subclasses must override this."""
        raise NotImplementedError()
class Combinator(object):
    """Parallel composition: every contained stage receives the same input.

    Evaluated by ``run``, which returns the stage results as a list in
    stage order.
    """

    def __init__(self, *stages):
        self.stages = stages

    def __or__(self, other):
        """Feed the collected results into ``other``."""
        return Pipeline(self, other)

    def __add__(self, other):
        """Return a new, flat combinator with ``other`` appended at the end."""
        extended = tuple(self.stages) + (other,)
        return Combinator(*extended)
class Pipeline(object):
    """Sequential composition: each stage consumes the previous one's output.

    Built by ``a | b``; evaluated lazily by ``run``.
    """

    def __init__(self, *stages):
        # Stored exactly as passed; nested pipelines are walked at run time.
        self.stages = stages

    def __or__(self, other):
        """Append ``other`` as a further step (wraps rather than flattens)."""
        chained = Pipeline(self, other)
        return chained

    def __add__(self, other):
        """Run this whole pipeline side by side with ``other``."""
        fanned = Combinator(self, other)
        return fanned
def run(stage, **initial):
    """Evaluate a lazily built stage graph and return its result.

    ``initial`` keyword arguments form the input of the first stage(s).

    * ``Stage``: calls ``stage.do`` with the previous result splatted as
      ``**prev`` when it is a mapping, otherwise as ``*prev``.
    * ``Pipeline``: threads the result through each sub-stage in order.
    * ``Combinator``: evaluates every sub-stage on the same input and
      returns a list of their results, in stage order.

    Raises:
        NotImplementedError: if a node is not a Stage, Pipeline or Combinator.
    """
    # Python 3 moved the ABCs to collections.abc (3.10 removed the old
    # aliases) and dropped the ``reduce`` builtin; these imports also work
    # on Python 2.
    try:
        from collections.abc import Mapping
    except ImportError:  # Python 2
        from collections import Mapping
    from functools import reduce

    def recur(node, prev):
        if isinstance(node, Combinator):
            # Build a real list: on Python 3 ``map`` is a lazy one-shot
            # iterator, which breaks equality checks and re-use of results.
            return [recur(child, prev) for child in node.stages]
        if isinstance(node, Pipeline):
            return reduce(lambda acc, child: recur(child, acc),
                          node.stages, prev)
        if isinstance(node, Stage):
            if isinstance(prev, Mapping):
                return node.do(**prev)
            return node.do(*prev)
        raise NotImplementedError(type(node))

    return recur(stage, initial)
## decorator for lazy me ,,_o.O_,,
def stage(f):
    """Turn a plain function into a ``Stage`` subclass.

    ``f`` is installed as the subclass's ``do`` method. It is therefore called
    with the stage instance as ``self`` (constructor kwargs are reachable via
    ``self.constants``), followed by the previous stage's output.

    Returns the new class; instantiate it (e.g. ``operate(x=1)``) to get a
    stage.
    """
    class _S(Stage):
        # Carry the wrapped function's docstring onto the generated class.
        __doc__ = f.__doc__

        def __repr__(self):
            # ``__name__`` works on Python 2 and 3. ``func_name`` was removed
            # in Python 3 and made repr() raise AttributeError.
            return "{}({})".format(f.__name__, repr(self.constants))

        do = f

    # Name the class after the function so it is recognisable in tracebacks
    # and debuggers instead of the generic ``_S``.
    _S.__name__ = f.__name__
    return _S
## test composition using operators for the win
@stage
def operate(self, y=0):
    """Add the ``x`` constant to the incoming ``y`` (default 0)."""
    offset = self.constants['x']
    return dict(y=offset + y)
@stage
def total(self, *args):
    """Sum the field named by the ``key`` constant across all inputs."""
    # Every lookup completes before summing starts, and the key is only
    # read when there is at least one input.
    values = [row[self.constants['key']] for row in args]
    return {'total': sum(values)}
def test_pipe():
    """``|`` threads each stage's output into the next stage."""
    first = operate(x=1) | operate(x=2)
    assert run(first) == {'y': 3}
    second = operate(x=3) | operate(x=4)
    # Pipelines nest: y goes 0 -> 1 -> 3 -> 6 -> 10.
    combined = first | second
    assert run(combined) == {'y': 10}
def test_comb():
    """``+`` fans one input out to several stages and collects a list."""
    fan_out = operate(x=1) + operate(x=2) + operate(x=3)
    assert run(fan_out) == [{'y': 1}, {'y': 2}, {'y': 3}]
    # y=1 is fed to every branch (-> 2, 3, 4), then the branches are summed.
    summary = operate(x=1) | fan_out | total(key='y')
    assert run(summary) == {'total': 9}
## mimic an imperative API for great success
@stage
def train_X(self):
    """Emit the fixed training feature rows."""
    rows = [['foo', 1], ['bar', 2]]
    return dict(dataframe=rows)
@stage
def train_y(self):
    """Emit the fixed training labels, aligned with ``train_X`` rows."""
    labels = [True, False]
    return dict(dataframe=labels)
@stage
def actual_X(self):
    """Emit the fixed rows to be labelled."""
    rows = [['foo', 3], ['baz', 0]]
    return dict(dataframe=rows)
@stage
def learner(self, X, y):
    """Map each training row's first column to its label (last one wins)."""
    rules = {}
    for (name, _ignored), label in zip(X['dataframe'], y['dataframe']):
        rules[name] = label
    return {'rules': rules}
@stage
def predictor(self, X, classifier):
    """Label each row of X by its first column, else the ``missing`` value."""
    fallback = self.constants['missing']
    labels = []
    for name, _ignored in X['dataframe']:
        labels.append(classifier['rules'].get(name, fallback))
    return {'dataframe': labels}
@stage
def wrapper(self):
    """Emit this stage's own constructor kwargs (the same dict object)."""
    stored = self.constants
    return stored
class ShittyClassifier(object):
    """Toy scikit-learn-style facade over the lazy stage graph.

    ``fit``, ``predict`` and ``load`` only build graphs; nothing is computed
    until a returned graph is passed to ``run``.
    """

    def __init__(self, missing=None):
        # Label that ``predict`` emits for rows whose key was never learned.
        self.missing = missing

    def fit(self, X, y):
        """Build, remember and return a graph that learns rules from X, y."""
        # ``X + y`` evaluates both on the same input; their two outputs are
        # then passed positionally to ``learner`` as (X, y).
        training = X + y
        self.classifier = training | learner()
        return self.classifier

    def predict(self, X):
        """Return a graph labelling X with the remembered classifier."""
        inputs = X + self.classifier
        return inputs | predictor(missing=self.missing)

    def load(self, **kwargs):
        """Use precomputed output (e.g. ``rules=...``) instead of fitting."""
        # A stage that simply re-emits kwargs stands in for the fit graph.
        self.classifier = wrapper(**kwargs)
## thanks to voodoo, everything lazy until `run`
def test_scikit_like_api():
    """Drive the stage graph through a fit / predict / load facade."""
    first = ShittyClassifier(missing='NaNaNaN')
    fitted = first.fit(train_X(), train_y())
    assert run(fitted) == {'rules': {'foo': True, 'bar': False}}
    unseen = actual_X()
    assert run(first.predict(unseen)) == {'dataframe': [True, 'NaNaNaN']}
    # Reuse the already computed rules in a fresh classifier, without refitting.
    rules = run(fitted)
    second = ShittyClassifier(missing='Batman!')
    second.load(**rules)
    assert run(second.predict(unseen)) == {'dataframe': [True, 'Batman!']}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment