Skip to content

Instantly share code, notes, and snippets.

@krishnanraman
Created July 12, 2016 02:35
Show Gist options
  • Save krishnanraman/a77b4e51afdb7190758390d383a258d6 to your computer and use it in GitHub Desktop.
Save krishnanraman/a77b4e51afdb7190758390d383a258d6 to your computer and use it in GitHub Desktop.
Visualize a decision tree (especially regression trees) in HTML5 using the canvas API
object ClusterTree {
/*
Draw a decision tree in html5 using the canvas api
Returns a valid html5 string, that can be persisted to some foo.html
*/
/**
 * Base drawing primitive: four integer fields plus a caption.
 * How `w` and `h` are interpreted depends on the subclass.
 */
class Rect(val x: Int, val y: Int, val w: Int, val h: Int, val text: String)

/** Vertex: a tree node rendered as a `w` x `h` box with its corner at (`x`, `y`). */
case class V(
  override val x: Int,
  override val y: Int,
  override val w: Int,
  override val h: Int,
  override val text: String
) extends Rect(x, y, w, h, text)

/**
 * Edge: a line from (`x`, `y`) to (`w`, `h`).
 * Note that for edges `w` and `h` carry the end point, not a size.
 */
case class E(
  override val x: Int,
  override val y: Int,
  override val w: Int,
  override val h: Int,
  override val text: String
) extends Rect(x, y, w, h, text)
/**
 * Layout settings for the rendered tree (all sizes in canvas pixels).
 *
 * @param leafWidth  width of every node box
 * @param spacerW    horizontal gap between neighbouring leaf boxes
 * @param spacerH    vertical gap between consecutive rows of boxes
 * @param leafHeight height of every node box
 * @param leafText   prefix printed in front of a leaf's predicted value
 */
case class TreeConfig(
  leafWidth: Int = 200,
  spacerW: Int = 20,
  spacerH: Int = 60,
  leafHeight: Int = 40,
  leafText: String = "CVR = "
)
/**
 * Truncates (towards zero) a value to three decimal places for display.
 *
 * The scaled value is truncated through a `Long`; going through `Int` would saturate at
 * `Int.MaxValue` and silently turn every input above ~2,147,483 into 2147483.647.
 * NaN, infinities and magnitudes too large to scale safely are returned unchanged.
 *
 * @param a the value to shorten
 * @return `a` with everything after the third decimal digit dropped
 */
def to3d(a: Double): Double = {
  val scaled = a * 1000
  val unscalable =
    java.lang.Double.isNaN(scaled) ||
    java.lang.Double.isInfinite(scaled) ||
    math.abs(scaled) >= 9.0e18 // beyond this the Long conversion would saturate
  if (unscalable) a else scaled.toLong / 1000.0
}
/**
 * Renders a decision tree as a self-contained HTML5 page that paints the tree on a canvas.
 *
 * Layout: leaves occupy the bottom row in left-to-right tree order; every internal node is
 * centred horizontally between its children and sits one row above the higher of them.
 * Edges are labelled "Y" (to the left child) and "N" (to the right child) and are painted
 * before the boxes so the boxes cover the line ends.
 *
 * @param model                the tree to draw
 * @param featureIndexToName   feature index -> label used in the split text
 * @param featureIndexToHeader feature index -> column header; when it differs from the label,
 *                             both are shown
 * @param leafCountM           optional item count per leaf; each key must contain
 *                             "ID: <node id>", optionally followed by a comma and more text
 * @param config               box sizes, gaps and the leaf text prefix
 * @return the complete HTML document
 * @throws NoSuchElementException if a split's feature index is missing from either feature map
 * @throws NumberFormatException  if a `leafCountM` key does not contain a parsable node id
 */
def draw(model: DecisionTreeModel,
  featureIndexToName: Map[Int, String],
  featureIndexToHeader: Map[Int, String],
  leafCountM: Map[String, Long] = Map(), // how many items land on which leaf id
  config: TreeConfig = TreeConfig()): String = {
  import collection.mutable.ListBuffer

  // config
  val leafWidth = config.leafWidth
  val spacerW = config.spacerW
  val spacerH = config.spacerH
  val leafHeight = config.leafHeight
  val predictionText = config.leafText

  val allNodes: List[MyNodeStats] = nodeStats(model)
  val lookupNode: Map[Int, MyNodeStats] = allNodes.map { x => (x.node, x) }.toMap

  // Extract the node id from keys shaped like "... ID: 12, ...". The comma is searched
  // only after the marker so an earlier comma in the key cannot cut the id short.
  val idMarker = "ID: "
  val leafCount: Map[Int, Long] = leafCountM.map { case (key, count) =>
    val from = key.indexOf(idMarker) + idMarker.length
    val comma = key.indexOf(",", from)
    val end = if (comma == -1) key.length else comma
    (key.slice(from, end).trim.toInt, count)
  }

  val rowStep = leafHeight + spacerH
  val height = (1 + model.depth) * rowStep

  /** Escapes text for embedding inside a single-quoted JavaScript literal within a script tag. */
  def js(text: String): String =
    text
      .replace("\\", "\\\\")
      .replace("'", "\\'")
      .replace("\r", "\\r")
      .replace("\n", "\\n")
      .replace("</", "<\\/")

  /** Caption of a leaf: prefix, truncated prediction and, when known, the item count. */
  def leafText(n: MyNodeStats): String = {
    val suffix = leafCount.get(n.node).filter(_ != -1).map(c => ", Keywords: " + c).getOrElse("")
    predictionText + to3d(n.me.predict.predict) + suffix
  }

  /** Caption of an internal node: the question asked by its split. */
  def splitText(n: MyNodeStats): String = {
    val split: Split = n.me.split.getOrElse(
      throw new NoSuchElementException("internal node " + n.node + " has no split"))
    val column = featureIndexToHeader(split.feature)
    val columnlabel = featureIndexToName(split.feature)
    val name = if (column == columnlabel) column else column + ":" + columnlabel
    if (split.featureType == FeatureType.Categorical)
      name + " is in " + split.categories.toString + " ?"
    else if (split.threshold == 0.0 && column != columnlabel)
      // a zero threshold on a header:label feature reads as "header is not label"
      column + " != " + columnlabel + " ?"
    else
      name + " <= " + to3d(split.threshold) + " ?"
  }

  val vertices = new ListBuffer[Rect]
  val edges = new ListBuffer[Rect]
  var nextLeafX = 0 // x of the next free slot on the bottom row

  def centre(r: Rect): (Int, Int) = (r.x + leafWidth / 2, r.y + leafHeight / 2)

  // Depth-first, left child first: leaves therefore receive their x slots in true
  // left-to-right order (node ids alone do not guarantee that order), and a parent is
  // positioned only once all of its children have final coordinates.
  def layout(n: MyNodeStats): Rect = {
    val children: List[(Rect, String)] =
      if (n.isLeaf) Nil
      else List((n.left, "Y"), (n.right, "N")).flatMap { case (id, label) =>
        lookupNode.get(id).map(child => (layout(child), label)).toList
      }

    val box =
      if (children.isEmpty) {
        val caption = if (n.isLeaf) leafText(n) else splitText(n)
        val v = V(nextLeafX, height - leafHeight, leafWidth, leafHeight, caption)
        nextLeafX += leafWidth + spacerW
        v
      } else {
        val first = children.head._1
        val last = children.last._1
        val x = first.x + (last.x - first.x) / 2
        // one row above the highest child, so a shallow sibling cannot pull the parent down
        val y = children.map(_._1.y).min - rowStep
        V(x, y, leafWidth, leafHeight, splitText(n))
      }

    vertices += box
    val (cx, cy) = centre(box)
    children.foreach { case (child, label) =>
      val (tx, ty) = centre(child)
      edges += E(cx, cy, tx, ty, label) // for edges, w/h hold the end point
    }
    box
  }

  lookupNode.get(model.topNode.id).foreach(root => layout(root))

  // every bottom-row slot consumed one (leafWidth + spacerW) step
  val width = nextLeafX

  val header =
    s"""
       |<!DOCTYPE HTML>
       |<html>
       |<head>
       |<style>
       |body {
       |margin: 0px;
       |padding: 0px;
       |}
       |</style>
       |</head>
       |<body>
       |<canvas id="myCanvas" width="$width" height="$height"></canvas>
       |<script>
       |var canvas = document.getElementById('myCanvas');
       |var context = canvas.getContext('2d');
       |context.font = '10pt Arial';
       |""".stripMargin
  val footer =
    """
      |</script>
      |</body>
      |</html>
      |""".stripMargin

  val out = new StringBuilder
  out.append(header)
  edges.foreach { e =>
    out.append(
      s"""context.beginPath();
         |context.moveTo(${e.x},${e.y});
         |context.lineTo(${e.w},${e.h});
         |context.stroke();
         |context.font = '10pt Arial';
         |context.fillText('${js(e.text)}', ${(e.x + e.w) / 2}, ${(e.y + e.h) / 2});
         |""".stripMargin)
  }
  vertices.foreach { v =>
    out.append(
      s"""context.beginPath();
         |context.rect(${v.x},${v.y},${v.w},${v.h});
         |context.fillStyle = 'yellow';
         |context.fill();
         |context.lineWidth = 1;
         |context.strokeStyle = 'black';
         |context.stroke();
         |context.fillStyle = 'blue';
         |context.fillText('${js(v.text)}', ${v.x + 2}, ${v.y + v.h / 2});
         |""".stripMargin)
  }
  out.append(footer)
  out.toString
}
/** Collects per-node stats for every node reachable from the model's top node. */
def nodeStats(model: DecisionTreeModel): List[MyNodeStats] = {
  val root = model.topNode
  val everyNode = traverse(root)
  nodeStats(everyNode)
}
/**
 * Collects the given node and all of its descendants, ordered by ascending node id.
 *
 * The subtree is gathered first and sorted once at the end; sorting inside every
 * recursive call yields the same list but repeats the work at each level.
 *
 * @param x root of the (sub)tree to walk
 * @return every node of the subtree, sorted by id
 */
def traverse(x: Node): List[Node] = {
  // node itself, then its left subtree, then its right subtree
  def collect(n: Node): List[Node] =
    n :: (n.leftNode.toList ++ n.rightNode.toList).flatMap(collect)

  collect(x).sortBy(_.id)
}
/**
 * List[Node] => List[MyNodeStats]
 *
 * Each result records the node, its id, the ids of its left and right children (-1 when
 * absent), its leaf flag and the id of its parent (-1 when no listed node names it as a child).
 *
 * @param nodes the nodes to describe; parents are resolved only among these nodes
 * @return one stats record per input node, in input order
 */
def nodeStats(nodes: List[Node]): List[MyNodeStats] = {
  def idOrMissing(child: Option[Node]): Int = child.map(_.id).getOrElse(-1)

  // First pass: everything except the parent, which needs the full list to resolve.
  val withoutParent: List[MyNodeStats] = nodes.map { n =>
    MyNodeStats(n, n.id, idOrMissing(n.leftNode), idOrMissing(n.rightNode), n.isLeaf)
  }

  // child id -> parent id, built once instead of rescanning the list for every node.
  // Built from the reversed list so that, on duplicate keys, the earliest record wins.
  val parentViaLeft: Map[Int, Int] = withoutParent.reverse.map(s => s.left -> s.node).toMap
  val parentViaRight: Map[Int, Int] = withoutParent.reverse.map(s => s.right -> s.node).toMap

  // Second pass: a match as somebody's left child takes precedence over a right-child match.
  withoutParent.map { s =>
    val me = s.me
    val parent = parentViaLeft.get(me.id).orElse(parentViaRight.get(me.id)).getOrElse(-1)
    MyNodeStats(me, s.node, s.left, s.right, s.isLeaf, parent)
  }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment