Skip to content

Instantly share code, notes, and snippets.

@schaunwheeler
Created April 13, 2019 12:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save schaunwheeler/f06b527112b8230870a3e99b9ba28d09 to your computer and use it in GitHub Desktop.
Save schaunwheeler/f06b527112b8230870a3e99b9ba28d09 to your computer and use it in GitHub Desktop.
Function to dump a trained scikit-learn RandomForestRegressor to JSON
from json import dumps
def rfr_to_json(rfr_object, feature_list, json_filepath=None):
    """
    Convert a fitted scikit-learn RandomForestRegressor to JSON.

    The output describes a two-stage pipeline: an 'extract-features' stage
    listing the input features in order, followed by a stage holding the
    flattened node arrays of every tree in the forest.

    Args:
        rfr_object: Fitted forest. Must expose ``estimators_``, an iterable of
            trees whose ``tree_`` attribute provides ``feature``,
            ``threshold``, ``children_left``, ``children_right`` (each with a
            ``tolist()`` method) and ``value``.
        feature_list: Feature names, in the column order used for training
            (assumed -- the function only records the order it is given).
        json_filepath: Optional path. When given, the JSON is written to this
            file; otherwise it is printed to stdout.

    Returns:
        None. The JSON is either written to ``json_filepath`` or printed.
    """
    # Imported locally so the function does not depend on which names the
    # module header pulled from ``json`` (the file branch needs ``dump``).
    import json

    # Sentinel scikit-learn stores in ``tree_.feature`` / ``tree_.threshold``
    # for leaf nodes. Mirrors ``sklearn.tree._tree.TREE_UNDEFINED`` (-2) so the
    # consumer of the JSON can recognise leaves.
    # NOTE(review): hard-coded to avoid importing a private sklearn module --
    # confirm it still matches the installed scikit-learn version.
    tree_undefined = -2

    extract_stage = {
        'task': 'extract-features',
        'features': [
            {'type': 'double', 'name': name, 'order': order}
            for order, name in enumerate(feature_list)
        ],
    }

    trees = []
    for i, model in enumerate(rfr_object.estimators_):
        tree = model.tree_
        trees.append({
            'i': i,
            'tree_undefined': tree_undefined,
            'features': tree.feature.tolist(),
            'thresholds': tree.threshold.tolist(),
            'children_left': tree.children_left.tolist(),
            'children_right': tree.children_right.tolist(),
            # Only the first entry of the first output is kept per node;
            # float() turns numpy scalars into plain JSON-serialisable floats.
            # NOTE(review): multi-output forests would lose the other
            # outputs here -- presumably single-output regression is intended.
            'values': [float(val[0][0]) for val in tree.value],
        })

    forest_stage = {
        'outputType': 'double',
        'pythonObject': 'sklearn.ensemble.RandomForestRegressor',
        'nTrees': len(trees),
        'trees': trees,
    }

    output_dict = {
        'name': 'rf_regression_pipeline',
        'transformer': {'stages': [extract_stage, forest_stage]},
    }

    if json_filepath is not None:
        with open(json_filepath, 'w') as json_file:
            json.dump(output_dict, json_file)
    else:
        print(json.dumps(output_dict))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment