@timelyportfolio
Last active August 29, 2015 14:06
rCharts + d3.js interactive view of an rpart / partykit recursive partitioning object

In reply to this tweet:

Anyone made interactive (tooltips, zoomable, etc.) plots of CART analyses in R? @timelyportfolio @yaacovp @xieyihui @ramnath_vaidya?

— Andrew Menzel (@menzbasketball) September 3, 2014

I put together a quick rCharts example by forking this fine collapsible d3 tree with pan/zoom.

The code for this particular example is included in this gist. The rCharts_rpart repo also contains some experiments and the pieces needed for rCharts.

live example

This is a work in progress, so please provide feedback, comments, and suggestions.
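
The R code in this gist serializes the partykit node list to JSON and then renames fields with `gsub` so the d3 tree can read it: `kids` becomes `children` and `"id":N` becomes `"name":"nodeN"`. As a rough illustration of the same reshaping done on a parsed object rather than on the JSON string, here is a minimal JavaScript sketch. The function name `toD3Node` and the sample input are hypothetical; they are not part of the gist or of rCharts.

```javascript
// Hypothetical helper (not in the gist): recursively reshape a
// partykit-style node ({id, kids, ...}) into the {name, children, ...}
// shape that the d3 tree code in this gist reads.
function toD3Node(node) {
  var out = {};
  Object.keys(node).forEach(function (key) {
    if (key === "id") {
      // "id":3 -> "name":"node3"
      out.name = "node" + node.id;
    } else if (key === "kids") {
      // "kids" -> "children", converting each child the same way
      out.children = (node.kids || []).map(toD3Node);
    } else {
      // copy everything else (split, surrogates, info) unchanged
      out[key] = node[key];
    }
  });
  return out;
}

// Made-up sample input, for illustration only.
var sample = {
  id: 1,
  split: { varid: 2, breaks: 7 },
  kids: [
    { id: 2, split: [], kids: [] },
    { id: 3, split: [], kids: [] }
  ]
};

var converted = toD3Node(sample);
console.log(JSON.stringify(converted));
```

Working on the parsed object avoids rewriting a `kids` substring that happens to appear somewhere else in the JSON; the regex approach in the R code is shorter and is sufficient for this example.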

library(rpart)
library(partykit)
library(rCharts)
library(jsonlite)
#set up a little rpart as an example
rp <- rpart(
hp ~ cyl + disp + mpg + drat + wt + qsec + vs + am + gear + carb,
method = "anova",
data = mtcars,
control = rpart.control(minsplit = 4)
)
#convert it to a partykit party object so we can use its node structure
rpk <- as.party(rp)
#set up rCharts
#key is to define how to handle the data
rChartsRpart <- setRefClass(
  "rChartsRpart",
  contains = "Dimple",
  methods = list(
    initialize = function(){
      callSuper()
    },
    getPayload = function (chartId) {
      # unclass every leaf of the party node so jsonlite can serialize it
      data = jsonlite::toJSON(
        rapply(params$data$node, unclass, how = "replace"),
        auto_unbox = TRUE
      )
      # the d3 tree reads "children" and "name" rather than partykit's "kids" and "id"
      data = gsub(x = data, pattern = "kids", replacement = "children")
      data = gsub(x = data, pattern = '"id":([0-9]*)', replacement = '"name":"node\\1"')
      chart = toChain(params$chart, "myChart")
      controls_json = toJSON(params$controls)
      controls = setNames(params$controls, NULL)
      opts = toJSON2(params[!(names(params) %in% c("data", "chart", "controls"))])
      list(opts = opts, data = data, chart = chart, chartId = chartId,
           controls = controls, controls_json = controls_json)
    }
  )
)
# now make a rChart with our rpart
rpRc <- rChartsRpart$new()
rpRc$setLib("http://timelyportfolio.github.io/rCharts_rpart")
rpRc$lib = "rpart_tree"
rpRc$LIB$name = "rpart_tree"
rpRc$setTemplate(
chartDiv = "<{{container}} id = '{{ chartId }}' class = '{{ lib }}' style = 'height:100%;width:100%;'></{{ container}}>"
)
rpRc$set(
data = rpk
, height = 800
, width = 800
)
rpRc
<!doctype html>
<html>
<head>
<meta charset='utf-8'>
<link rel='stylesheet' href='http://timelyportfolio.github.io/rCharts_rpart/css/treestyle.css'>
<script src='http://d3js.org/d3.v3.min.js' type='text/javascript'></script>
<style>
.rChart {
display: block;
margin-left: auto;
margin-right: auto;
width: 400px;
height: 400px;
}
</style>
</head>
<body >
<div id = 'chart2aa039cf291c' class = 'rpart_tree' style = 'height:400px;width:100%;display:inline-block;'></div>
<pre>
Model formula:
hp ~ cyl + disp + mpg + drat + wt + qsec + vs + am + gear + carb
Fitted party:
[1] root
| [2] cyl < 7
| | [3] mpg >= 21.45
| | | [4] disp < 87.05: 62.250 (n = 4, err = 140.8)
| | | [5] disp >= 87.05: 91.833 (n = 6, err = 1376.8)
| | [6] mpg < 21.45
| | | [7] qsec >= 15.98: 112.857 (n = 7, err = 306.9)
| | | [8] qsec < 15.98: 175.000 (n = 1, err = 0.0)
| [9] cyl >= 7
| | [10] drat < 3.18
| | | [11] mpg >= 12.8: 170.000 (n = 7, err = 1150.0)
| | | [12] mpg < 12.8: 210.000 (n = 2, err = 50.0)
| | [13] drat >= 3.18
| | | [14] carb < 6: 246.000 (n = 4, err = 582.0)
| | | [15] carb >= 6: 335.000 (n = 1, err = 0.0)
Number of inner nodes: 7
Number of terminal nodes: 8
</pre>
<script>
// Chart options (opts) and the tree data (treeData) are defined inline below
(function(){
var opts = {
"dom": "chart2aa039cf291c",
"width": 400,
"height": 400,
"xAxis": {
"type": "addCategoryAxis",
"showPercent": false
},
"yAxis": {
"type": "addMeasureAxis",
"showPercent": false
},
"zAxis": [],
"colorAxis": [],
"defaultColors": [],
"layers": [],
"legend": [],
"id": "chart2aa039cf291c"
};
var treeData = {"name":"node1","split":{"varid":2,"breaks":7,"index":[],"right":false,"prob":[1,0],"info":[]},"children":[{"name":"node2","split":{"varid":4,"breaks":21.45,"index":[2,1],"right":false,"prob":[1,0],"info":[]},"children":[{"name":"node3","split":{"varid":3,"breaks":87.05,"index":[],"right":false,"prob":[0,1],"info":[]},"children":[{"name":"node4","split":[],"children":[],"surrogates":[],"info":[]},{"name":"node5","split":[],"children":[],"surrogates":[],"info":[]}],"surrogates":[{"varid":4,"breaks":26.65,"index":[2,1],"right":false,"prob":[],"info":[]},{"varid":5,"breaks":4,"index":[2,1],"right":false,"prob":[],"info":[]},{"varid":6,"breaks":2.0375,"index":[],"right":false,"prob":[],"info":[]},{"varid":7,"breaks":19.95,"index":[],"right":false,"prob":[],"info":[]},{"varid":9,"breaks":0.5,"index":[2,1],"right":false,"prob":[],"info":[]}],"info":[]},{"name":"node6","split":{"varid":7,"breaks":15.98,"index":[2,1],"right":false,"prob":[1,0],"info":[]},"children":[{"name":"node7","split":[],"children":[],"surrogates":[],"info":[]},{"name":"node8","split":[],"children":[],"surrogates":[],"info":[]}],"surrogates":[],"info":[]}],"surrogates":[{"varid":2,"breaks":5,"index":[],"right":false,"prob":[],"info":[]},{"varid":3,"breaks":120.65,"index":[],"right":false,"prob":[],"info":[]},{"varid":6,"breaks":2.5425,"index":[],"right":false,"prob":[],"info":[]},{"varid":11,"breaks":3,"index":[],"right":false,"prob":[],"info":[]},{"varid":5,"breaks":3.655,"index":[2,1],"right":false,"prob":[],"info":[]}],"info":[]},{"name":"node9","split":{"varid":5,"breaks":3.18,"index":[],"right":false,"prob":[1,0],"info":[]},"children":[{"name":"node10","split":{"varid":4,"breaks":12.8,"index":[2,1],"right":false,"prob":[1,0],"info":[]},"children":[{"name":"node11","split":[],"children":[],"surrogates":[],"info":[]},{"name":"node12","split":[],"children":[],"surrogates":[],"info":[]}],"surrogates":[{"varid":3,"breaks":430,"index":[],"right":false,"prob":[],"info":[]},{"varid":6,"breaks":4.66,"index":[],"right":false,"prob":[],"info":[]},{"varid":11,"breaks":3.5,"index":[],"right":false,"prob":[],"info":[]},{"varid":5,"breaks":3.035,"index":[2,1],"right":false,"prob":[],"info":[]},{"varid":7,"breaks":17.71,"index":[],"right":false,"prob":[],"info":[]}],"info":[]},{"name":"node13","split":{"varid":11,"breaks":6,"index":[],"right":false,"prob":[1,0],"info":[]},"children":[{"name":"node14","split":[],"children":[],"surrogates":[],"info":[]},{"name":"node15","split":[],"children":[],"surrogates":[],"info":[]}],"surrogates":[],"info":[]}],"surrogates":[{"varid":7,"breaks":16.355,"index":[2,1],"right":false,"prob":[],"info":[]},{"varid":11,"breaks":3.5,"index":[],"right":false,"prob":[],"info":[]},{"varid":4,"breaks":15.1,"index":[2,1],"right":false,"prob":[],"info":[]},{"varid":9,"breaks":0.5,"index":[],"right":false,"prob":[],"info":[]},{"varid":10,"breaks":4,"index":[],"right":false,"prob":[],"info":[]}],"info":[]}],"surrogates":[{"varid":3,"breaks":266.9,"index":[],"right":false,"prob":[],"info":[]},{"varid":4,"breaks":17.55,"index":[2,1],"right":false,"prob":[],"info":[]},{"varid":6,"breaks":3.49,"index":[],"right":false,"prob":[],"info":[]},{"varid":5,"breaks":3.58,"index":[2,1],"right":false,"prob":[],"info":[]},{"varid":8,"breaks":0.5,"index":[2,1],"right":false,"prob":[],"info":[]}],"info":[]};
// Calculate total nodes, max label length
var totalNodes = 0;
var maxLabelLength = 0;
// variables for drag/drop
var selectedNode = null;
var draggingNode = null;
// panning variables
var panSpeed = 200;
var panBoundary = 20; // Within 20px from edges will pan when dragging.
// Misc. variables
var i = 0;
var duration = 750;
var root;
// define the baseSvg, attaching a class for styling and the zoomListener
var baseSvg = d3.select("#" + opts.id).append("svg")
.attr("width", "100%")
.attr("height", "100%")
.attr("class", "overlay")
// size of the diagram
var viewerWidth = d3.select("#" + opts.id)[0][0].getBoundingClientRect().width;
var viewerHeight = d3.select("#" + opts.id)[0][0].getBoundingClientRect().height;
var tree = d3.layout.tree()
.size([viewerHeight, viewerWidth]);
// define a d3 diagonal projection for use by the node paths later on.
var diagonal = d3.svg.diagonal()
.projection(function(d) {
return [d.y, d.x];
});
// A recursive helper function for performing some setup by walking through all nodes
function visit(parent, visitFn, childrenFn) {
if (!parent) return;
visitFn(parent);
var children = childrenFn(parent);
if (children) {
var count = children.length;
for (var i = 0; i < count; i++) {
visit(children[i], visitFn, childrenFn);
}
}
}
// Call visit function to establish maxLabelLength
visit(treeData, function(d) {
totalNodes++;
maxLabelLength = Math.max(d.name.length, maxLabelLength);
}, function(d) {
return d.children && d.children.length > 0 ? d.children : null;
});
// sort the tree according to the node names
function sortTree() {
tree.sort(function(a, b) {
return b.name.toLowerCase() < a.name.toLowerCase() ? 1 : -1;
});
}
// Sort the tree initially in case the JSON isn't in sorted order.
sortTree();
// TODO: Pan function, can be better implemented.
function pan(domNode, direction) {
var speed = panSpeed;
if (panTimer) {
clearTimeout(panTimer);
translateCoords = d3.transform(svgGroup.attr("transform"));
if (direction == 'left' || direction == 'right') {
translateX = direction == 'left' ? translateCoords.translate[0] + speed : translateCoords.translate[0] - speed;
translateY = translateCoords.translate[1];
} else if (direction == 'up' || direction == 'down') {
translateX = translateCoords.translate[0];
translateY = direction == 'up' ? translateCoords.translate[1] + speed : translateCoords.translate[1] - speed;
}
scaleX = translateCoords.scale[0];
scaleY = translateCoords.scale[1];
scale = zoomListener.scale();
svgGroup.transition().attr("transform", "translate(" + translateX + "," + translateY + ")scale(" + scale + ")");
d3.select(domNode).select('g.node').attr("transform", "translate(" + translateX + "," + translateY + ")");
zoomListener.scale(zoomListener.scale());
zoomListener.translate([translateX, translateY]);
panTimer = setTimeout(function() {
// pan() takes (domNode, direction); passing speed here would shift direction out of place
pan(domNode, direction);
}, 50);
}
}
// Define the zoom function for the zoomable tree
function zoom() {
svgGroup.attr("transform", "translate(" + d3.event.translate + ")scale(" + d3.event.scale + ")");
}
// define the zoomListener which calls the zoom function on the "zoom" event constrained within the scaleExtents
var zoomListener = d3.behavior.zoom().scaleExtent([0.1, 3]).on("zoom", zoom);
function initiateDrag(d, domNode) {
draggingNode = d;
d3.select(domNode).select('.ghostCircle').attr('pointer-events', 'none');
d3.selectAll('.ghostCircle').attr('class', 'ghostCircle show');
d3.select(domNode).attr('class', 'node activeDrag');
svgGroup.selectAll("g.node").sort(function(a, b) { // select the parent and sort the paths
if (a.id != draggingNode.id) return 1; // a is not the hovered element, send "a" to the back
else return -1; // a is the hovered element, bring "a" to the front
});
// if nodes has children, remove the links and nodes
if (nodes.length > 1) {
// remove link paths
links = tree.links(nodes);
nodePaths = svgGroup.selectAll("path.link")
.data(links, function(d) {
return d.target.id;
}).remove();
// remove child nodes
nodesExit = svgGroup.selectAll("g.node")
.data(nodes, function(d) {
return d.id;
}).filter(function(d, i) {
if (d.id == draggingNode.id) {
return false;
}
return true;
}).remove();
}
// remove parent link
parentLink = tree.links(tree.nodes(draggingNode.parent));
svgGroup.selectAll('path.link').filter(function(d, i) {
if (d.target.id == draggingNode.id) {
return true;
}
return false;
}).remove();
dragStarted = null;
}
baseSvg
.call(zoomListener);
/*
// Define the drag listeners for drag/drop behaviour of nodes.
dragListener = d3.behavior.drag()
.on("dragstart", function(d) {
if (d == root) {
return;
}
dragStarted = true;
nodes = tree.nodes(d);
d3.event.sourceEvent.stopPropagation();
// it's important that we suppress the mouseover event on the node being dragged. Otherwise it will absorb the mouseover event and the underlying node will not detect it
d3.select(this).attr('pointer-events', 'none');
})
.on("drag", function(d) {
if (d == root) {
return;
}
if (dragStarted) {
domNode = this;
initiateDrag(d, domNode);
}
// get coords of mouseEvent relative to svg container to allow for panning
relCoords = d3.mouse(svgGroup[0][0]);
if (relCoords[0] < panBoundary) {
panTimer = true;
pan(this, 'left');
} else if (relCoords[0] > (viewerWidth - panBoundary)) {
panTimer = true;
pan(this, 'right');
} else if (relCoords[1] < panBoundary) {
panTimer = true;
pan(this, 'up');
} else if (relCoords[1] > (viewerHeight - panBoundary)) {
panTimer = true;
pan(this, 'down');
} else {
try {
clearTimeout(panTimer);
} catch (e) {
}
}
d.x0 += d3.event.dy;
d.y0 += d3.event.dx;
var node = d3.select(this);
node.attr("transform", "translate(" + d.y0 + "," + d.x0 + ")");
updateTempConnector();
}).on("dragend", function(d) {
if (d == root) {
return;
}
domNode = this;
if (selectedNode) {
// now remove the element from the parent, and insert it into the new elements children
var index = draggingNode.parent.children.indexOf(draggingNode);
if (index > -1) {
draggingNode.parent.children.splice(index, 1);
}
if (typeof selectedNode.children !== 'undefined' || typeof selectedNode._children !== 'undefined') {
if (typeof selectedNode.children !== 'undefined') {
selectedNode.children.push(draggingNode);
} else {
selectedNode._children.push(draggingNode);
}
} else {
selectedNode.children = [];
selectedNode.children.push(draggingNode);
}
// Make sure that the node being added to is expanded so user can see added node is correctly moved
expand(selectedNode);
sortTree();
endDrag();
} else {
endDrag();
}
});
function endDrag() {
selectedNode = null;
d3.selectAll('.ghostCircle').attr('class', 'ghostCircle');
d3.select(domNode).attr('class', 'node');
// now restore the mouseover event or we won't be able to drag a 2nd time
d3.select(domNode).select('.ghostCircle').attr('pointer-events', '');
updateTempConnector();
if (draggingNode !== null) {
update(root);
centerNode(draggingNode);
draggingNode = null;
}
}
*/
// Helper functions for collapsing and expanding nodes.
function collapse(d) {
if (d.children) {
d._children = d.children;
d._children.forEach(collapse);
d.children = null;
}
}
function expand(d) {
if (d._children) {
d.children = d._children;
d.children.forEach(expand);
d._children = null;
}
}
var overCircle = function(d) {
selectedNode = d;
updateTempConnector();
};
var outCircle = function(d) {
selectedNode = null;
updateTempConnector();
};
// Function to update the temporary connector indicating dragging affiliation
var updateTempConnector = function() {
var data = [];
if (draggingNode !== null && selectedNode !== null) {
// have to flip the source coordinates since we did this for the existing connectors on the original tree
data = [{
source: {
x: selectedNode.y0,
y: selectedNode.x0
},
target: {
x: draggingNode.y0,
y: draggingNode.x0
}
}];
}
var link = svgGroup.selectAll(".templink").data(data);
link.enter().append("path")
.attr("class", "templink")
.attr("d", d3.svg.diagonal())
.attr('pointer-events', 'none');
link.attr("d", d3.svg.diagonal());
link.exit().remove();
};
// Function to center node when clicked/dropped so node doesn't get lost when collapsing/moving with large amount of children.
function centerNode(source) {
var scale = zoomListener.scale();
// source.y0 / source.x0 are swapped because the tree is drawn left to right
var x = -source.y0 * scale + viewerWidth / 2;
var y = -source.x0 * scale + viewerHeight / 2;
// move the group that holds the tree rather than the first <g> on the page
svgGroup.transition()
.duration(duration)
.attr("transform", "translate(" + x + "," + y + ")scale(" + scale + ")");
zoomListener.scale(scale);
zoomListener.translate([x, y]);
}
// Toggle children function
function toggleChildren(d) {
if (d.children) {
d._children = d.children;
d.children = null;
} else if (d._children) {
d.children = d._children;
d._children = null;
}
return d;
}
// Toggle children on click.
function click(d) {
if (d3.event.defaultPrevented) return; // click suppressed
d = toggleChildren(d);
update(d);
centerNode(d);
}
function update(source) {
// Compute the new height, function counts total children of root node and sets tree height accordingly.
// This prevents the layout looking squashed when new nodes are made visible or looking sparse when nodes are removed
// This makes the layout more consistent.
var levelWidth = [1];
var childCount = function(level, n) {
if (n.children && n.children.length > 0) {
if (levelWidth.length <= level + 1) levelWidth.push(0);
levelWidth[level + 1] += n.children.length;
n.children.forEach(function(d) {
childCount(level + 1, d);
});
}
};
childCount(0, root);
var newHeight = d3.max(levelWidth) * 25; // 25 pixels per line
tree = tree.size([newHeight, viewerWidth]);
// Compute the new tree layout.
var nodes = tree.nodes(root).reverse(),
links = tree.links(nodes);
// Set widths between levels based on maxLabelLength.
nodes.forEach(function(d) {
d.y = (d.depth * (maxLabelLength * 10)); //maxLabelLength * 10px
// alternatively to keep a fixed scale one can set a fixed depth per level
// Normalize for fixed depth by uncommenting the line below
// d.y = (d.depth * 500); //500px per level.
});
// Update the nodes…
var node = svgGroup.selectAll("g.node")
.data(nodes, function(d) {
return d.id || (d.id = ++i);
});
// Enter any new nodes at the parent's previous position.
var nodeEnter = node.enter().append("g")
// .call(dragListener)
.attr("class", "node")
.attr("transform", function(d) {
return "translate(" + source.y0 + "," + source.x0 + ")";
})
.on('click', click);
nodeEnter.append("circle")
.attr('class', 'nodeCircle')
.attr("r", 0)
.style("fill", function(d) {
return d._children ? "lightsteelblue" : "#fff";
});
nodeEnter.append("text")
.attr("x", function(d) {
return d.children || d._children ? -10 : 10;
})
.attr("dy", ".35em")
.attr('class', 'nodeText')
.attr("text-anchor", function(d) {
return d.children || d._children ? "end" : "start";
})
.text(function(d) {
return d.name;
})
.style("fill-opacity", 0);
// phantom node to give us mouseover in a radius around it
nodeEnter.append("circle")
.attr('class', 'ghostCircle')
.attr("r", 30)
.attr("opacity", 0.2) // change this to zero to hide the target area
.style("fill", "red")
.attr('pointer-events', 'mouseover')
.on("mouseover", function(node) {
overCircle(node);
})
.on("mouseout", function(node) {
outCircle(node);
});
// Update the text to reflect whether node has children or not.
node.select('text')
.attr("x", function(d) {
return d.children || d._children ? -10 : 10;
})
.attr("text-anchor", function(d) {
return d.children || d._children ? "end" : "start";
})
.text(function(d) {
return d.name;
});
// Change the circle fill depending on whether it has children and is collapsed
node.select("circle.nodeCircle")
.attr("r", 4.5)
.style("fill", function(d) {
return d._children ? "lightsteelblue" : "#fff";
});
// Transition nodes to their new position.
var nodeUpdate = node.transition()
.duration(duration)
.attr("transform", function(d) {
return "translate(" + d.y + "," + d.x + ")";
});
// Fade the text in
nodeUpdate.select("text")
.style("fill-opacity", 1);
// Transition exiting nodes to the parent's new position.
var nodeExit = node.exit().transition()
.duration(duration)
.attr("transform", function(d) {
return "translate(" + source.y + "," + source.x + ")";
})
.remove();
nodeExit.select("circle")
.attr("r", 0);
nodeExit.select("text")
.style("fill-opacity", 0);
// Update the links…
var link = svgGroup.selectAll("path.link")
.data(links, function(d) {
return d.target.id;
});
// Enter any new links at the parent's previous position.
link.enter().insert("path", "g")
.attr("class", "link")
.attr("d", function(d) {
var o = {
x: source.x0,
y: source.y0
};
return diagonal({
source: o,
target: o
});
});
// Transition links to their new position.
link.transition()
.duration(duration)
.attr("d", diagonal);
// Transition exiting links to the parent's new position.
link.exit().transition()
.duration(duration)
.attr("d", function(d) {
var o = {
x: source.x,
y: source.y
};
return diagonal({
source: o,
target: o
});
})
.remove();
// Stash the old positions for transition.
nodes.forEach(function(d) {
d.x0 = d.x;
d.y0 = d.y;
});
}
// Append a group which holds all nodes and which the zoom Listener can act upon.
var svgGroup = baseSvg.append("g");
// Define the root
root = treeData;
root.x0 = viewerHeight / 2;
root.y0 = 0;
// Layout the tree initially and center on the root node.
update(root);
centerNode(root);
})();
</script>
</body>
</html>
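
The `update` function in the script above sizes the layout by counting how many visible nodes sit at each depth and allowing 25 pixels per node at the widest level; children hidden in `_children` by `toggleChildren` are not counted. A minimal standalone sketch of that calculation, runnable without d3 or a browser, might look like the following. The names `levelWidths`, `toggle`, and `demo` are hypothetical and were introduced only for this illustration.

```javascript
// Count visible nodes per depth, mirroring what update() does with levelWidth.
function levelWidths(root) {
  var widths = [1]; // depth 0 holds only the root
  (function count(level, n) {
    if (n.children && n.children.length > 0) {
      if (widths.length <= level + 1) widths.push(0);
      widths[level + 1] += n.children.length;
      n.children.forEach(function (d) { count(level + 1, d); });
    }
  })(0, root);
  return widths;
}

// Same idea as toggleChildren(): hide children by moving them to _children,
// or restore them if they are currently hidden.
function toggle(d) {
  if (d.children) {
    d._children = d.children;
    d.children = null;
  } else if (d._children) {
    d.children = d._children;
    d._children = null;
  }
  return d;
}

// Made-up tree, for illustration only.
var demo = { name: "node1", children: [
  { name: "node2", children: [{ name: "node3" }, { name: "node4" }] },
  { name: "node5", children: [{ name: "node6" }, { name: "node7" }] }
]};

var before = levelWidths(demo);                        // widths with everything expanded
var heightBefore = Math.max.apply(null, before) * 25;  // 25 pixels per node, as in update()
toggle(demo.children[0]);                              // collapse node2
var after = levelWidths(demo);                         // node2's children no longer counted
var heightAfter = Math.max.apply(null, after) * 25;
console.log(before, heightBefore, after, heightAfter);
```

Collapsing a node shrinks the widest level, so the computed height drops and the remaining nodes are not spread over space reserved for hidden ones.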