Created
July 12, 2016 02:35
-
-
Save krishnanraman/a77b4e51afdb7190758390d383a258d6 to your computer and use it in GitHub Desktop.
Visualize a decision tree (especially a regression tree) in HTML5 using the canvas API
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
object ClusterTree { | |
/*
Draws a decision tree in HTML5 using the canvas API.
`draw` returns a complete HTML5 document as a string, which can be saved to a file such as foo.html.
*/
/* | |
define vertex V & edge E | |
*/ | |
/**
 * A positioned, labelled drawing primitive. Everything queued for rendering is a `Rect`;
 * the concrete subclass decides how the four numbers are interpreted.
 */
class Rect(
  val x: Int,
  val y: Int,
  val w: Int,
  val h: Int,
  val text: String
)

/**
 * Vertex: a box whose top-left corner is (`x`, `y`), `w` wide and `h` high,
 * with `text` written inside it.
 */
case class V(
  override val x: Int,
  override val y: Int,
  override val w: Int,
  override val h: Int,
  override val text: String
) extends Rect(x, y, w, h, text)

/**
 * Edge: a straight line labelled `text`. The line starts at (`x`, `y`); note that `w` and `h`
 * are reused as the END POINT coordinates of the line, not as a width and height.
 */
case class E(
  override val x: Int,
  override val y: Int,
  override val w: Int,
  override val h: Int,
  override val text: String
) extends Rect(x, y, w, h, text)
/**
 * Layout settings for `draw`. All sizes are in canvas pixels.
 *
 * @param leafWidth  width of every node box (leaves and internal nodes alike)
 * @param spacerW    horizontal gap between neighbouring leaf boxes
 * @param spacerH    vertical gap between a node box and the row below it
 * @param leafHeight height of every node box
 * @param leafText   prefix written before the predicted value in each leaf box
 */
case class TreeConfig(
  leafWidth: Int = 200,
  spacerW: Int = 20,
  spacerH: Int = 60,
  leafHeight: Int = 40,
  leafText: String = "CVR = "
)
/**
 * Truncates `a` toward zero to at most three decimal places (used for display labels).
 *
 * The scaled value is truncated through `Long` rather than `Int`: with `Int`, any input whose
 * magnitude exceeds roughly 2,147,483.647 saturated and was silently displayed as 2147483.647.
 * Results for inputs inside the old `Int` range are unchanged. `NaN` still maps to 0.0.
 *
 * @param a the value to shorten
 * @return `a` with everything beyond the third decimal digit dropped
 */
def to3d(a:Double): Double = (a * 1000).toLong / 1000.0
/**
 * Renders `model` as a standalone HTML5 page that paints the tree on a `<canvas>`.
 *
 * Layout: leaves sit on the bottom row, left to right in ascending node-id order. Every internal
 * node is placed one row above the higher of its two children and horizontally midway between
 * them. Edges to the left/right child are labelled "Y"/"N". Edges are painted first and boxes
 * afterwards, so boxes cover the line ends.
 *
 * @param model                the decision tree to draw
 * @param featureIndexToName   feature index -> label shown for that feature
 * @param featureIndexToHeader feature index -> column header; when header and label differ the node
 *                             is captioned "header:label"
 * @param leafCountM           optional item count per leaf. The leaf id is parsed from the text
 *                             between "ID: " and the first "," of each key.
 * @param config               box sizes, spacing and the leaf text prefix
 * @return the complete HTML document
 * @throws NoSuchElementException if a split's feature index is missing from either feature map,
 *                                or a non-leaf node has no split
 * @throws NumberFormatException  if a `leafCountM` key does not contain a parsable leaf id
 */
def draw(model:DecisionTreeModel,
         featureIndexToName: Map[Int, String],
         featureIndexToHeader: Map[Int, String],
         leafCountM:Map[String, Long] = Map(), // how many items land on which leaf id
         config:TreeConfig = TreeConfig()):String = {
  // imports
  import collection.mutable.{HashMap, Queue}

  // config
  val leafWidth = config.leafWidth
  val spacerW = config.spacerW
  val spacerH = config.spacerH
  val leafHeight = config.leafHeight
  val predictionText = config.leafText

  // Labels are emitted inside single-quoted JavaScript string literals within an inline <script>.
  // Escape the characters that would terminate the literal or the script element; ordinary labels
  // (including "<=") pass through unchanged.
  def jsEscape(s: String): String =
    s.map {
      case '\\' => "\\\\"
      case '\'' => "\\'"
      case '\n' => "\\n"
      case '\r' => "\\r"
      case c    => c.toString
    }.mkString.replace("</", "<\\/")

  // index the tree
  val allNodes: List[MyNodeStats] = nodeStats(model)
  val allLeaf = allNodes.filter(_.isLeaf)
  val lookupNode: Map[Int, MyNodeStats] = allNodes.map(x => (x.node, x)).toMap

  // leaf id -> item count, with the id extracted from between "ID: " and the first ","
  val leafCount: Map[Int, Long] = leafCountM.map { case (k, v) =>
    val a = k.indexOf("ID: ")
    val b = k.indexOf(",")
    (k.slice(a + 4, b).toInt, v)
  }

  // compute tree dimensions
  val width = allLeaf.size * (leafWidth + spacerW)
  val height = (1 + model.depth) * (leafHeight + spacerH)

  // work queue, seeded with the leaves; parents are appended as their children are visited
  val nodeStack: Queue[MyNodeStats] = new Queue()
  allLeaf.sortBy(_.node).foreach(leaf => nodeStack.enqueue(leaf))
  // a repo of nodes that already have a position
  val nodeMap: HashMap[Int, Rect] = new HashMap()
  // everything to be painted, in creation order
  val drawStack: Queue[Rect] = new Queue()

  // Schedules n's parent unless it is already queued or already placed.
  def enqueueParent(n: MyNodeStats): Unit =
    if (n.parent != -1) {
      val parent = lookupNode(n.parent)
      if (!nodeMap.contains(parent.node) && !nodeStack.contains(parent)) nodeStack.enqueue(parent)
    }

  var currLeafX = 0
  // Number of consecutive deferrals since the last placement; see the stall guard below.
  var stalled = 0
  while (nodeStack.nonEmpty) {
    val n = nodeStack.dequeue()
    if (n.isLeaf) {
      // leaves deserve special treatment: their x & y are uniquely determined
      val countText =
        leafCount.get(n.node).filter(_ != -1L).map(c => ", Keywords: " + c).getOrElse("")
      val text = predictionText + to3d(n.me.predict.predict) + countText
      val v = V(currLeafX, height - leafHeight, leafWidth, leafHeight, text)
      currLeafX += leafWidth + spacerW
      drawStack.enqueue(v)
      nodeMap += (n.node -> v)
      enqueueParent(n)
    } else {
      enqueueParent(n)
      // A child that exists in the tree but has no position yet means this node was reached too
      // early (its sibling subtree is still being laid out). Placing it now would draw it on top
      // of the one finished child, so send it to the back of the queue instead.
      val childPending = List(n.left, n.right).exists { id =>
        lookupNode.contains(id) && !nodeMap.contains(id)
      }
      // Stall guard: if every queued node has been deferred once with no placement in between,
      // the tree is malformed and waiting longer cannot help; fall through and place best-effort.
      if (childPending && stalled <= nodeStack.size) {
        stalled += 1
        nodeStack.enqueue(n)
      } else {
        stalled = 0

        val split: Split = n.me.split.get
        val column = featureIndexToHeader(split.feature)
        val columnlabel = featureIndexToName(split.feature)
        val name = if (column == columnlabel) column else column + ":" + columnlabel
        val text =
          if (split.featureType == FeatureType.Categorical)
            name + " is in " + split.categories.toString + " ?"
          else if (split.threshold == 0.0 && column != columnlabel)
            // threshold 0 on a header:label feature is shown as an inequality test
            column + " != " + columnlabel + " ?"
          else
            name + " <= " + to3d(split.threshold) + " ?"

        // children that have a position, paired with their edge label
        val children: List[(Rect, String)] =
          List(nodeMap.get(n.left).map(r => (r, "Y")), nodeMap.get(n.right).map(r => (r, "N"))).flatten

        val x = children match {
          case List((l, _), (r, _)) => l.x + (r.x - l.x) / 2 // midway between both children
          case (only, _) :: Nil     => only.x                // single child: directly above it
          case _ =>
            throw new NoSuchElementException("node " + n.node + " has no laid-out child to anchor to")
        }
        // One row above the HIGHER child, so the box never lands inside a taller sibling subtree.
        val y = children.map(_._1.y).min - spacerH - leafHeight

        val v = V(x, y, leafWidth, leafHeight, text)
        drawStack.enqueue(v)
        nodeMap += (n.node -> v)

        // edges run from this box's centre to each child's centre
        val mex = x + leafWidth / 2
        val mey = y + leafHeight / 2
        children.foreach { case (child, label) =>
          drawStack.enqueue(E(mex, mey, child.x + leafWidth / 2, child.y + leafHeight / 2, label))
        }
      }
    }
  }

  val canvasTag =
    "<canvas id=\"myCanvas\" width=\"" + width + "\" height=\"" + height + "\"></canvas>"
  val header =
    s"""
       |<!DOCTYPE HTML>
       |<html>
       |<head>
       |<style>
       |body {
       |margin: 0px;
       |padding: 0px;
       |}
       |</style>
       |</head>
       |<body>
       |$canvasTag
       |<script>
       |var canvas = document.getElementById('myCanvas');
       |var context = canvas.getContext('2d');
       |context.font = '10pt Arial';
       |""".stripMargin
  val footer =
    """
      |</script>
      |</body>
      |</html>
      |""".stripMargin

  val out = new StringBuilder()
  out.append(header)

  val (edges, vertices) = drawStack.partition {
    case _: E => true
    case _    => false
  }

  // For an edge, (x, y) is the start point and (w, h) the end point; the label goes at the midpoint.
  edges.foreach { e =>
    out.append(Seq(
      "context.beginPath();",
      s"context.moveTo(${e.x},${e.y});",
      s"context.lineTo(${e.w},${e.h});",
      "context.stroke();",
      "context.font = '10pt Arial';",
      s"context.fillText('${jsEscape(e.text)}', ${(e.x + e.w) / 2}, ${(e.y + e.h) / 2});"
    ).mkString("\n"))
  }

  vertices.foreach { v =>
    out.append(Seq(
      "context.beginPath();",
      s"context.rect(${v.x},${v.y},${v.w},${v.h});",
      "context.fillStyle = 'yellow';",
      "context.fill();",
      "context.lineWidth = 1;",
      "context.strokeStyle = 'black';",
      "context.stroke();",
      "context.fillStyle = 'blue';",
      s"context.fillText('${jsEscape(v.text)}', ${v.x + 2}, ${v.y + v.h / 2});"
    ).mkString("\n"))
  }

  out.append(footer)
  out.toString
}
/**
 * Builds the per-node statistics list for a whole model: walks the tree from `model.topNode`
 * with `traverse`, then converts the resulting node list with the `List[Node]` overload.
 */
def nodeStats(model:DecisionTreeModel):List[MyNodeStats] = {
  val orderedNodes: List[Node] = traverse(model.topNode)
  nodeStats(orderedNodes)
}
/** | |
A recursive routine to traverse the decision tree given the root, build a sorted list of nodes | |
*/ | |
def traverse(x:Node):List[Node] = {
  // An absent child contributes nothing; a present child contributes its whole subtree.
  def subtree(child: Option[Node]): List[Node] = child.map(traverse).getOrElse(Nil)

  // This node first, then everything below its left child, then everything below its right child.
  val collected = x :: (subtree(x.leftNode) ::: subtree(x.rightNode))
  // Ordered by node id at every level of the recursion, so the final list is sorted by id.
  collected.sortBy(_.id)
}
/** | |
List[Node] => List[MyNodeStats] | |
*/ | |
def nodeStats(nodes:List[Node]):List[MyNodeStats] = {
  // Child ids use -1 to mean "no such child".
  def idOrMissing(child: Option[Node]): Int = child.map(_.id).getOrElse(-1)

  // Pass 1: one record per node holding its own id and its child ids. The parent argument is
  // omitted here, so it takes whatever default MyNodeStats declares.
  val withoutParents: List[MyNodeStats] = nodes.map { n =>
    MyNodeStats(n, n.id, idOrMissing(n.leftNode), idOrMissing(n.rightNode), n.isLeaf)
  }

  // Pass 2: a node's parent is the first record that lists it as its left child, otherwise the
  // first record that lists it as its right child, otherwise -1 (no parent found).
  withoutParents.map { stats =>
    val self = stats.me
    val parentId = withoutParents.find(_.left == self.id)
      .orElse(withoutParents.find(_.right == self.id))
      .map(_.node)
      .getOrElse(-1)
    MyNodeStats(self, stats.node, stats.left, stats.right, stats.isLeaf, parentId)
  }
}
} | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment