Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lalinsky/a49b34994a850b5e647aa61f9095e1f9 to your computer and use it in GitHub Desktop.
Save lalinsky/a49b34994a850b5e647aa61f9095e1f9 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"metadata": {
"name": "DecisionTreeClassificationModel to JSON"
},
"nbformat": 3,
"nbformat_minor": 0,
"worksheets": [
{
"cells": [
{
"cell_type": "code",
"collapsed": false,
"input": "from pyspark.ml.linalg import Vectors\nfrom pyspark.ml.feature import StringIndexer\nfrom pyspark.ml.classification import DecisionTreeClassifier\n\ndf = spark.createDataFrame([\n (1.0, Vectors.dense(1.0)),\n (0.0, Vectors.sparse(1, [], []))], [\"label\", \"features\"])\n\nstringIndexer = StringIndexer(inputCol=\"label\", outputCol=\"indexed\")\nsi_model = stringIndexer.fit(df)\ntd = si_model.transform(df)\n\ndt = DecisionTreeClassifier(maxDepth=2, labelCol=\"indexed\")\nmodel = dt.fit(td)",
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 1
},
{
"cell_type": "code",
"collapsed": false,
"input": "print model.toDebugString",
"language": "python",
"metadata": {},
"outputs": [
{
"output_type": "stream",
"stream": "stdout",
"text": "DecisionTreeClassificationModel (uid=DecisionTreeClassifier_4c7db90739bec3e34a3c) of depth 1 with 3 nodes\n If (feature 0 <= 0.0)\n Predict: 0.0\n Else (feature 0 > 0.0)\n Predict: 1.0\n\n"
}
],
"prompt_number": 2
},
{
"cell_type": "code",
"collapsed": false,
"input": "def convert_node(node, features, rules, attribute=None):\n if node.getClass().getName() == 'org.apache.spark.ml.tree.LeafNode':\n return {\n 'attribute': attribute,\n 'rules': rules,\n 'prediction': node.prediction(),\n 'good': 0,\n 'bad': 0\n }\n\n split = node.split()\n original_attribute = attribute\n attribute = features[split.featureIndex()]\n\n if split.getClass().getName() == 'org.apache.spark.ml.tree.ContinuousSplit':\n threshold = split.threshold()\n left_rules = '{}<={}'.format(attribute, threshold)\n right_rules = '{}>{}'.format(attribute, threshold)\n elif split.getClass().getName() == 'org.apache.spark.ml.tree.CategoricalSplit':\n categories = split.leftCategories().mkString(\"{\", \",\", \"}\")\n left_rules = '{} in {}'.format(attribute, categories)\n right_rules = '{} not in {}'.format(attribute, categories)\n else:\n raise ValueError('unknown split class')\n \n children = [\n convert_node(node.leftChild(), features, left_rules, attribute),\n convert_node(node.rightChild(), features, right_rules, attribute)\n ]\n\n return {\n 'attribute_id': split.featureIndex(),\n 'attribute': original_attribute,\n 'threshold': threshold,\n 'prediction': node.prediction(),\n 'good': 0,\n 'bad': 0,\n 'rules': rules,\n 'children': children\n }\n\ndef convert_model(model, features):\n return convert_node(model._call_java('rootNode'), features, '')\n\nimport pprint\npprint.pprint(convert_model(model, ['feature1']))",
"language": "python",
"metadata": {},
"outputs": [
{
"output_type": "stream",
"stream": "stdout",
"text": "{'attribute': None,\n 'attribute_id': 0,\n 'bad': 0,\n 'children': [{'attribute': 'feature1',\n 'bad': 0,\n 'good': 0,\n 'prediction': 0.0,\n 'rules': 'feature1<=0.0'},\n {'attribute': 'feature1',\n 'bad': 0,\n 'good': 0,\n 'prediction': 1.0,\n 'rules': 'feature1>0.0'}],\n 'good': 0,\n 'prediction': 0.0,\n 'rules': '',\n 'threshold': 0.0}\n"
}
],
"prompt_number": 3
}
],
"metadata": {}
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment