Skip to content

Instantly share code, notes, and snippets.

@danthedaniel
Last active July 10, 2017 07:39
Show Gist options
  • Save danthedaniel/32a750b314585a95934835036ee89129 to your computer and use it in GitHub Desktop.
Save danthedaniel/32a750b314585a95934835036ee89129 to your computer and use it in GitHub Desktop.
Reversible BNF in TPOT
import numpy as np
# Search space for the "light" classifier configuration.
#
# Every key is a dotted import path to an operator class; every value maps a
# constructor argument name to the candidate values that argument may take.
# The entries are grouped by role and merged, in this order, at the bottom.

# Classifiers
_CLASSIFIERS = {
    'sklearn.naive_bayes.GaussianNB': {},
    'sklearn.naive_bayes.BernoulliNB': {
        'alpha': [1e-3, 1e-2, 1e-1, 1., 10., 100.],
        'fit_prior': [True, False],
    },
}

# Preprocessors
_PREPROCESSORS = {
    'sklearn.preprocessing.Binarizer': {
        'threshold': np.arange(0.0, 1.01, 0.05),
    },
    'sklearn.cluster.FeatureAgglomeration': {
        'linkage': ['ward', 'complete', 'average'],
        'affinity': ['euclidean', 'l1', 'l2', 'manhattan', 'cosine', 'precomputed'],
    },
}

# Selectors
_SELECTORS = {
    'sklearn.feature_selection.VarianceThreshold': {
        'threshold': np.arange(0.05, 1.01, 0.05),
    },
}

classifier_config_dict_light = {**_CLASSIFIERS, **_PREPROCESSORS, **_SELECTORS}
import importlib
import random
import re
from collections.abc import Iterable
from copy import deepcopy

from sklearn.pipeline import make_pipeline
from tpot.builtins import StackingEstimator

from config import classifier_config_dict_light
# An atom written as "$name" refers to the grammar rule called "name".
variable_re = re.compile(r'^\$(\S+)$')


def reverse(rules, start='S'):
    """Reverse a CFG to produce a random word in its language.

    One production of the ``start`` rule is picked with ``random.choice``.
    Each atom of that production is converted with ``str``; an atom of the
    form ``$name`` is expanded recursively from the rule ``name``, any other
    atom is kept as a terminal string.

    Parameters
    ----------
    rules : dict (str -> set(tuple))
        A set of rules defining a language.
    start : str
        The starting rule in the CFG.

    Returns
    -------
    list
        A parse tree for a word in the CFG: terminals appear as strings and
        every expanded rule reference appears as a nested list.

    Raises
    ------
    KeyError
        If ``start``, or a rule referenced during expansion, is not a key
        of ``rules``.
    IndexError
        If the rule being expanded has no productions to choose from.
    """
    try:
        rule = rules[start]
    except KeyError:
        raise KeyError('undefined grammar rule: {!r}'.format(start)) from None

    # random.choice needs a sequence, and rules are stored as sets.
    productions = tuple(rule)
    if not productions:
        raise IndexError('grammar rule {!r} has no productions'.format(start))

    tree = []
    for atom in random.choice(productions):
        text = str(atom)
        reference = variable_re.match(text)
        if reference is None:
            tree.append(text)
        else:
            tree.append(reverse(rules, start=reference.group(1)))
    return tree
def flatten(tree):
    """Yield the leaves of an arbitrarily nested iterable, depth first.

    Any element that is an ``Iterable`` is descended into, except ``str``
    and ``bytes`` values, which are yielded whole as leaves. Leaves are
    produced in left-to-right order.

    The traversal keeps its own stack of iterators instead of recursing, so
    the nesting depth of ``tree`` is not bounded by the interpreter's
    recursion limit.
    """
    pending = [iter(tree)]
    while pending:
        for item in pending[-1]:
            if isinstance(item, Iterable) and not isinstance(item, (str, bytes)):
                # Descend first; the partially consumed parent iterator stays
                # on the stack and is resumed once the child is exhausted.
                pending.append(iter(item))
                break
            yield item
        else:
            # The innermost iterator ran dry without finding a nested child.
            pending.pop()
def clean_tree(tree):
    """Collapse a CFG parse tree into the single string it spells out.

    The leaves are taken in the order ``flatten`` yields them and are
    concatenated with no separator.
    """
    return ''.join(flatten(tree))
def is_estimator(model):
    """Determine if a class is a machine-learning estimator.

    A model counts as an estimator when it exposes a ``predict`` attribute;
    anything else is treated by the caller as a preprocessing step.

    Returns
    -------
    bool
        True if ``model`` has a ``predict`` attribute, False otherwise.
    """
    return hasattr(model, 'predict')
def import_module(module_path):
    """Return the object named by a dotted path such as ``pkg.mod.Name``.

    Everything before the last dot is imported as a module and the final
    component is looked up on it as an attribute.

    Parameters
    ----------
    module_path : str
        Dotted path whose last component is the attribute to fetch.

    Returns
    -------
    The attribute found on the imported module.

    Raises
    ------
    ValueError
        If ``module_path`` has no module part before the final component.
    ImportError
        If the module part cannot be imported.
    AttributeError
        If the module has no attribute with the requested name.
    """
    module_name, _, attribute_name = module_path.rpartition('.')
    if not module_name:
        raise ValueError(
            'expected a dotted path like "package.module.Name", got {!r}'
            .format(module_path)
        )
    return getattr(importlib.import_module(module_name), attribute_name)
def grammar_from_config(config, grammar_base):
    """Generate a TPOT grammar from a TPOT config.

    For every entry of ``config`` the class is imported, a rule
    ``<ClassName>__<param>`` is created for each of its parameters (one
    single-atom production per candidate value), and a production spelling
    out the constructor call ``ClassName(param=$ClassName__param, ...), `` is
    added to the ``est`` rule when the class has a ``predict`` attribute and
    to the ``prep`` rule otherwise.

    Parameters
    ----------
    config : dict (str -> dict (str -> iterable))
        Maps a dotted class path to its parameters and their candidate
        values.
    grammar_base : dict (str -> set(tuple))
        Rules the generated grammar starts from. It is deep-copied, so the
        caller's object is not modified.

    Returns
    -------
    tuple
        ``(grammar, ctx)`` where ``grammar`` is the extended rule dict and
        ``ctx`` maps every name the generated source can reference
        (``make_pipeline``, ``StackingEstimator`` and each class name) to the
        corresponding object.

    Notes
    -----
    Rules and ``ctx`` entries are keyed by the class ``__name__``, so two
    config entries whose classes share a name overwrite each other.
    """
    grammar = deepcopy(grammar_base)
    ctx = {
        'make_pipeline': make_pipeline,
        'StackingEstimator': StackingEstimator,
    }

    for module_path, parameters in config.items():
        model_class = import_module(module_path)
        model_name = model_class.__name__

        # Add the model to our evaluation context
        ctx[model_name] = model_class

        call_atoms = [model_name, '(']
        for param_name, param_values in parameters.items():
            rule_name = '{}__{}'.format(model_name, param_name)
            # The trailing ', ' after each argument is still a valid call.
            call_atoms.extend([param_name, '=', '${}'.format(rule_name), ', '])

            # String values must appear as string literals in the generated
            # source; repr() also escapes embedded quotes and backslashes.
            grammar[rule_name] = {
                (repr(str(value)) if isinstance(value, str) else value,)
                for value in param_values
            }
        call_atoms.append('), ')

        # Add model to the set of estimators or preprocessors, creating the
        # rule if the base grammar did not declare it.
        model_type = 'est' if is_estimator(model_class) else 'prep'
        grammar[model_type] = grammar.get(model_type, set()) | {tuple(call_atoms)}

    return (grammar, ctx)
# Skeleton grammar for a pipeline. Each rule maps to a set of productions; a
# production is a tuple of atoms, and an atom starting with '$' refers to
# another rule. 'est' and 'prep' start empty and are filled from a config by
# grammar_from_config.
pipeline_base = {
    'make_pipeline': {
        ('make_pipeline', '(', '$pipeline', ')'),
    },
    # A pipeline is any number of intermediate steps followed by an estimator.
    'pipeline': {
        ('$ops', '$est'),
    },
    # The empty production is what ends the chain of intermediate steps.
    'ops': {
        ('$prep', '$ops'),
        ('$est_transform', '$ops'),
        (),
    },
    # NOTE(review): no other rule in this grammar refers to '$combine'.
    'combine': {
        ('make_union', '(', '$make_pipeline', ', ', '$make_pipeline', ')'),
    },
    'est': set(),
    'prep': set(),
    'est_transform': {
        ('StackingEstimator', '(', 'estimator=', '$est', '), '),
    },
}
# Demo driver: runs at import time.
# Extend the base grammar with the operators from the light config; ctx is the
# namespace the generated source is evaluated in below.
pipeline, ctx = grammar_from_config(classifier_config_dict_light, pipeline_base)
# Debugging aid: uncomment to inspect the generated grammar.
# print(pipeline)
# Draw one random derivation from the 'make_pipeline' rule and join its leaves
# into a single source string.
individual = clean_tree(reverse(pipeline, 'make_pipeline'))
print(individual)
# NOTE(review): eval executes the generated text with ctx as its globals. The
# text is assembled from the config's class names and parameter values, so
# this is only safe while the config is trusted.
evaluated_pipeline = eval(individual, ctx)
print(evaluated_pipeline)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment