Skip to content

Instantly share code, notes, and snippets.

@krishnanraman
Last active February 10, 2018 00:10
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save krishnanraman/d1dfba76b699bd8773ddb119193c9357 to your computer and use it in GitHub Desktop.
Save krishnanraman/d1dfba76b699bd8773ddb119193c9357 to your computer and use it in GitHub Desktop.
Get all leaves of the DecisionTree (then construct a spline through the leaf nodes to build f(x) => y)
// Exiting paste mode, now interpreting.
id = 8, isLeaf = true, predict = 0.0 (prob = -1.0), impurity = 0.0, split = None, stats = None
id = 9, isLeaf = true, predict = 1.4736842105263157 (prob = -1.0), impurity = 0.2493074792243767, split = None, stats = None
id = 10, isLeaf = true, predict = 3.0 (prob = -1.0), impurity = 0.16666666666666666, split = None, stats = None
id = 11, isLeaf = true, predict = 4.1 (prob = -1.0), impurity = 0.09000000000000057, split = None, stats = None
id = 12, isLeaf = true, predict = 5.0 (prob = -1.0), impurity = 0.0, split = None, stats = None
id = 13, isLeaf = true, predict = 6.444444444444445 (prob = -1.0), impurity = 0.2469135802469143, split = None, stats = None
id = 14, isLeaf = true, predict = 7.923076923076923 (prob = -1.0), impurity = 0.2248520710059158, split = None, stats = None
id = 15, isLeaf = true, predict = 9.0 (prob = -1.0), impurity = 0.0, split = None, stats = None
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.tree.configuration._
import org.apache.spark.mllib.tree._
import org.apache.spark.mllib.tree.model.Node
/**
 * Collects every leaf node reachable from the given MLlib decision-tree nodes.
 *
 * Each starting node is walked depth-first in post-order (left subtree, right subtree,
 * then the node itself). Leaves therefore come back in left-to-right order. A node
 * reachable from more than one starting node is reported only once.
 *
 * @param nl starting nodes, e.g. `List(model.topNode)`; may be empty
 * @return the nodes with `isLeaf == true` among `nl` and all of their descendants
 */
def getLeaves(nl: List[Node]): List[Node] = {
  // Prepends the leaves of `n`'s subtree (in post-order) onto `acc`.
  // The node is handled first, then the right child, then the left child. Because every
  // step prepends, the leaves end up in left-to-right order in front of `acc`.
  // The recursion depth equals the tree depth, which MLlib keeps small.
  def walk(n: Node, acc: List[Node]): List[Node] = {
    val withSelf = if (n.isLeaf) n :: acc else acc
    val withRight = n.rightNode.fold(withSelf)(walk(_, withSelf))
    n.leftNode.fold(withRight)(walk(_, withRight))
  }

  // foldRight yields walk(n1, walk(n2, ...)), i.e. the subtrees' leaves concatenated in
  // `nl` order. distinct drops repeats when starting nodes overlap (e.g. a node and its
  // ancestor are both passed in).
  nl.foldRight(List.empty[Node])(walk).distinct
}
// Toy regression data: a single feature x in 1..99 with label x / 10 (integer division).
// This is a step function, which the depth-limited tree below approximates piecewise.
val data = (1 until 100).toList.map { x =>
  val label = x / 10
  new LabeledPoint(label.toDouble, new DenseVector(Array(x.toDouble)))
}

// Regression tree limited to depth 3, so it has at most 2^3 = 8 leaves.
val strategy = Strategy.defaultStrategy(Algo.Regression)
strategy.setMaxDepth(3)
val tree = new DecisionTree(strategy)
// NOTE(review): `sc` is assumed to be the SparkContext that spark-shell provides; it is
// not defined in this script.
val model = tree.run(sc.makeRDD(data))

// The tree's root is the model's topNode (`root` was never defined here).
getLeaves(List(model.topNode)).foreach(println)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment