Skip to content

Instantly share code, notes, and snippets.

@himetani
Last active February 20, 2016 08:47
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save himetani/2b0a5bed69ddfa3c21ca to your computer and use it in GitHub Desktop.
Save himetani/2b0a5bed69ddfa3c21ca to your computer and use it in GitHub Desktop.
decision-tree
// Paste the generated rendering JSON here.
// Each node has a "name" label; nodes with descendants also carry a
// "children" array, which is what the tree layout below walks.
var data = {
"name": "petalLength <= 2.45",
"children": [
{
"name": "setosa: 50"
},
{
"name": "petalWidth <= 1.75",
"children": [
{
"name": "petalLength <= 5.35",
"children": [
{
"name": "versicolor: 49"
},
{
"name": "virginica: 3"
},
{
"name": "virginica: 2"
}
]
},
{
"name": "versicolor: 1"
},
{
"name": "virginica: 45"
}
]
}
]
}
/**
 * Creates a D3 tree layout and a diagonal link generator.
 *
 * Both objects are local and are neither used nor returned, and the
 * `target` / `data` parameters are never read, so calling this function
 * has no visible effect beyond invoking the D3 factories.
 *
 * @param {*} target - unused
 * @param {*} data - unused
 * @returns {undefined}
 */
function drawTree(target, data) {
  // Tree layout: 1000 x 700, with a constant separation of 1 between nodes
  var layoutFactory = d3.layout.tree();
  var sizedLayout = layoutFactory.size([1000, 700]);
  var treeLayout = sizedLayout.separation(function () {
    return 1;
  });

  // Function that creates a link path
  var linkPath = d3.svg.diagonal();
}
/**
 * Child accessor passed to the tree layout.
 *
 * @param {Object} d - a tree datum
 * @returns {Array|undefined} the datum's `children` property, if any
 */
function children(d) {
  var kids = d.children;
  return kids;
}
// ---- Render `data` as an SVG tree inside the #tree element ----

// Pixel size of the <svg> canvas.
var width = 1000;
var height = 1000;
// Tree layout sized 400 x 400, reading child nodes through `children`.
var tree = d3.layout.tree().size([400, 400]).children(children);
// Node list computed by the layout from `data`, and the links derived from those nodes.
var nodes = tree.nodes(data);
var links = tree.links(nodes);
// Append the <svg> to #tree; everything is drawn inside a <g> shifted by
// (200, 30), so `svg` below refers to that group, not the <svg> element.
var svg = d3.select("#tree")
.append("svg")
.attr("width", width)
.attr("height", height)
.append("g")
.attr("transform", "translate(200,30)");
// One <g class="node"> per node, translated to the node's (d.x, d.y).
var node = svg.selectAll(".node")
.data(nodes)
.enter()
.append("g")
.attr("class", "node")
.attr("transform", function(d) {
return "translate(" + d.x + "," + d.y + ")"; });
// Node marker: radius-10 circle, filled black when the node has
// `children` or `_children`, white otherwise.
// NOTE(review): SVG <circle> is positioned with cx/cy, so the "y"
// attribute set here probably has no visual effect - confirm before removing.
node.append("circle")
.attr("r", 10)
.attr("y", function(d) { return d.y+4;})
.attr("stroke", "black")
.attr("fill", function(d) {
return d.children || d._children ? "black" : "white"
})
// Node label: the datum's name, placed at y = -15 for nodes with
// children and at y = 20 for nodes without.
node.append("text")
.text(function(d) { return d.name})
.style("font-size", "10px")
.attr("x", -40)
.attr("y", function(d) {
return d.children || d._children ? -15 : 20
})
// Link path generator projecting each point as [x, y].
var diagonal = d3.svg.diagonal()
.projection(function(d) { return [d.x, d.y];});
// One unfilled, blue-stroked <path class="link"> per link.
svg.selectAll(".link")
.data(links)
.enter()
.append("path")
.attr("class", "link")
.attr("fill", "none")
.attr("stroke", "blue")
.attr("d",diagonal);
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<!-- D3 v3 is loaded over HTTPS: a plain http:// script is blocked as mixed
     content whenever this page itself is served over HTTPS. -->
<script src="https://d3js.org/d3.v3.min.js" charset="utf-8"></script>
<!-- NOTE(review): jquery-git2 is a rolling development build; consider
     pinning a released version, or dropping it if nothing uses jQuery. -->
<script src="https://code.jquery.com/jquery-git2.min.js" charset="utf-8"></script>
<link rel="stylesheet" href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.4/css/bootstrap.min.css">
<style>
.axis path,
.axis line {
fill: none;
stroke: black;
shape-rendering: crispEdges;
}
.axis text {
font-family: sans-serif;
}
</style>
</head>
<body>
<div class="container">
<div class="jumbotron">
<h1>Decision Tree</h1>
</div>
<!-- Mount point: the drawing script appends its <svg> to this element. -->
<div id="tree"></div>
</div>
<!-- Loaded last so that #tree already exists when the script runs. -->
<script type="text/javascript" src="./js/draw-tree.js"></script>
</body>
</html>
'use strict';
var CSV = require("comma-separated-values");
var fs = require("fs");
var _ = require("underscore");
var sys = require("sys");
// Decision-tree node (constructor)
/**
 * Creates a decision-tree node as a plain object. Because an object is
 * returned explicitly, the result is the same with or without `new`.
 *
 * @param {?number} col - index of the column this node splits on
 * @param {*} val - split value for that column
 * @param {Object} results - class counts of the rows that reached this node
 * @param {?Object} tb - subtree for the "true" side of the split
 * @param {?Object} fb - subtree for the "false" side of the split
 * @param {boolean} isEnd - true when the node is a leaf
 * @returns {Object} the node
 */
function DecisionNode(col, val, results, tb, fb, isEnd) {
  return {
    col: col,
    val: val,
    results: results,
    tb: tb,
    fb: fb,
    isEnd: isEnd
  };
}
// Count how many rows belong to each class
/**
 * Tallies rows by class label, where the label is the last element of each row.
 *
 * @param {Array<Array>} rows - data rows
 * @returns {Object} map of label -> number of rows carrying that label
 */
function uniqueCounts(rows) {
  if (rows.length === 0) {
    return {};
  }
  return rows.reduce(function (counts, row) {
    var label = row[row.length - 1];
    counts[label] = counts[label] >= 1 ? counts[label] + 1 : 1;
    return counts;
  }, {});
}
// Compute the Gini impurity
/**
 * Returns 1 minus the sum of squared class proportions of `rows`.
 * For an empty `rows` there are no classes to sum over, so the result is 1.
 *
 * @param {Array<Array>} rows - data rows (class label in the last element)
 * @returns {number} Gini impurity
 */
function calcGini(rows) {
  var total = rows.length;
  var counts = uniqueCounts(rows);
  var labels = Object.keys(counts);
  var sumOfSquares = 0;
  for (var i = 0; i < labels.length; i++) {
    var proportion = counts[labels[i]] / total;
    sumOfSquares += proportion * proportion;
  }
  return 1 - sumOfSquares;
}
// Split the rows in two on a given feature and a given value
/**
 * Partitions `rows` by comparing `row[column]` against `value`.
 * A numeric `value` uses `row[column] > value`; any other `value`
 * uses strict equality.
 *
 * @param {Array<Array>} rows - data rows
 * @param {number} column - index of the feature to test
 * @param {*} value - threshold (number) or exact match value
 * @returns {Array<Array>} `[rowsPassingTheTest, rowsFailingTheTest]`
 */
function divide(rows, column, value) {
  var isNumeric = typeof value === 'number';
  var matched = [];
  var unmatched = [];
  rows.forEach(function (row) {
    var hit = isNumeric ? row[column] > value : row[column] === value;
    (hit ? matched : unmatched).push(row);
  });
  return [matched, unmatched];
}
// Build the decision tree
/**
 * Recursively builds a binary decision tree from `rows`.
 *
 * For every feature column (all columns except the last, which holds the
 * class label), candidate thresholds are the midpoints between consecutive
 * distinct numeric values, rounded to two decimals. The candidate whose two
 * subsets have the lowest mean Gini impurity wins. The node is split only if
 * that score beats the impurity of `rows` and both subsets hold more than
 * one row; otherwise a leaf is returned.
 *
 * @param {Array<Array>} rows - data rows; the last element is the class label
 * @returns {Object} root node produced by `DecisionNode`; the split value
 *     `val` is kept as the two-decimal string produced by `toFixed(2)`
 */
function buildTree(rows) {
  var subResult = uniqueCounts(rows);

  // With no rows there is no first row to read the column count from.
  if (rows.length === 0) {
    return new DecisionNode(null, null, subResult, null, null, true);
  }

  var currentGini = calcGini(rows);
  var bestCol;
  var bestVal;
  var bestGini = 1;
  var bestSets = [];
  var cols = rows[0].length - 1;

  for (var col = 0; col < cols; col++) {
    // Sorted, de-duplicated numeric values of this column. Non-numeric
    // cells parse to NaN and are dropped: they can never yield a usable
    // threshold and would break the numeric ordering.
    var values = rows
      .map(function (row) { return parseFloat(row[col]); })
      .filter(function (value) { return !isNaN(value); })
      .sort(function (a, b) { return a - b; })
      .filter(function (value, index, sorted) {
        return index === 0 || value !== sorted[index - 1];
      });

    for (var i = 0; i < values.length - 1; i++) {
      var threshold = ((values[i + 1] + values[i]) / 2).toFixed(2);
      var dividedSets = divide(rows, col, parseFloat(threshold));

      // Rounding the midpoint can push it onto a data value and leave one
      // side empty; such a candidate is not a split at all.
      if (dividedSets[0].length === 0 || dividedSets[1].length === 0) {
        continue;
      }

      // NOTE(review): this is the unweighted mean of the two impurities;
      // CART implementations usually weight each side by its row count.
      // Kept as-is so existing output does not change - confirm intent.
      var gini = (calcGini(dividedSets[0]) + calcGini(dividedSets[1])) / 2;
      if (gini < bestGini) {
        bestCol = col;
        bestVal = threshold;
        bestGini = gini;
        bestSets[0] = dividedSets[0];
        bestSets[1] = dividedSets[1];
      }
    }
  }

  var canSplit = bestCol !== undefined &&
    bestGini < currentGini &&
    bestSets[0].length > 1 &&
    bestSets[1].length > 1;

  if (!canSplit) {
    return new DecisionNode(null, null, subResult, null, null, true);
  }

  var trueBranch = buildTree(bestSets[0]);
  var falseBranch = buildTree(bestSets[1]);
  return new DecisionNode(bestCol, bestVal, subResult, trueBranch, falseBranch, false);
}
// Print the generated tree to the console
/**
 * Writes a textual rendering of the tree to standard output.
 *
 * A leaf prints its `results`; an inner node prints "col:val" followed by
 * its true branch (prefixed "T->") and its false branch (prefixed "F->"),
 * each nested one space deeper.
 *
 * @param {Object} tree - node with `isEnd`, `results`, `col`, `val`, `tb`, `fb`
 * @param {string} [indent=''] - prefix written before the branch markers
 * @returns {undefined}
 */
function printTree(tree, indent) {
  // Treat both null and undefined as "no indent" so that calling
  // printTree(tree) does not print the text "undefined".
  var prefix = indent == null ? '' : indent;

  if (tree.isEnd) {
    console.log(tree.results);
    return;
  }

  console.log(tree.col + ":" + tree.val);
  // process.stdout.write prints without a trailing newline, which is what
  // sys.print did; sys.print (util.print) no longer exists in current Node.js.
  process.stdout.write(prefix + 'T->');
  printTree(tree.tb, prefix + ' ');
  process.stdout.write(prefix + 'F->');
  printTree(tree.fb, prefix + ' ');
}
// Generate the JSON used to draw the decision tree
/**
 * Converts a decision tree into the nested `{name, children}` shape used
 * for drawing.
 *
 * A leaf becomes one `{name: "label: count"}` object per class: a single
 * object when there is exactly one class, otherwise an array (empty when
 * there are no classes). An inner node becomes
 * `{name: "<attribute> <= <val>", children: [...]}` where the entries from
 * the false branch come first, followed by those from the true branch;
 * arrays returned by leaves are flattened into `children`.
 *
 * @param {Object} tree - decision-tree node
 * @param {Array<string>} attributes - column names, indexed by `tree.col`
 * @returns {Object|Array<Object>} drawable representation
 */
function createJSON(tree, attributes) {
  if (tree.isEnd) {
    var leaves = Object.keys(tree.results).map(function (label) {
      return { name: label + ": " + tree.results[label] };
    });
    return leaves.length === 1 ? leaves[0] : leaves;
  }

  var node = {
    name: attributes[tree.col] + " <= " + tree.val
  };
  var falseSide = createJSON(tree.fb, attributes);
  var trueSide = createJSON(tree.tb, attributes);
  // concat spreads array arguments and appends anything else as one entry.
  node.children = [].concat(falseSide, trueSide);
  return node;
}
/*
 * ******
 * main *
 * ******
 *
 * Reads the data file, treats its first row as the attribute names and the
 * remaining rows as samples, builds the tree, writes the drawing JSON to
 * result.json, and prints the tree.
 */
(function () {
  var raw = fs.readFileSync('../source/iris.data', 'utf-8');
  var table = new CSV(raw).parse();

  // First row holds the column names; everything after it is data.
  var attributes = table[0];
  var samples = table.slice(1);

  var root = buildTree(samples);
  var drawable = createJSON(root, attributes);
  var resultJSON = JSON.stringify(drawable, null, '\t');

  // The write is asynchronous; a failure is reported on stderr only.
  fs.writeFile('../source/result.json', resultJSON, function (err) {
    if (err) console.error(err);
  });

  printTree(root, '');
})();
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment