Last active
February 20, 2016 08:47
-
-
Save himetani/2b0a5bed69ddfa3c21ca to your computer and use it in GitHub Desktop.
decision-tree
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Tree data to render: the generated JSON, copy-pasted here.
// Each node has a "name" label; internal nodes also carry a "children" array.
var data = {
  "name": "petalLength <= 2.45",
  "children": [
    { "name": "setosa: 50" },
    {
      "name": "petalWidth <= 1.75",
      "children": [
        {
          "name": "petalLength <= 5.35",
          "children": [
            { "name": "versicolor: 49" },
            { "name": "virginica: 3" },
            { "name": "virginica: 2" }
          ]
        },
        { "name": "versicolor: 1" },
        { "name": "virginica: 45" }
      ]
    }
  ]
};
/**
 * Unfinished stub: builds a d3 tree layout and a diagonal link generator and
 * then discards both. Nothing is drawn and nothing is returned; the actual
 * rendering in this file is done by the top-level statements that follow.
 *
 * @param {*} target - Currently unused.
 * @param {*} data - Currently unused.
 * @returns {undefined}
 */
function drawTree(target, data) {
  // Tree layout
  var tree = d3.layout.tree().size([1000, 700]).separation(function() { return 1; });
  // Function for creating a link
  var diagonal = d3.svg.diagonal();
}
/**
 * Child accessor handed to the tree layout.
 *
 * @param {Object} d - A node of the hierarchy.
 * @returns {Array|undefined} The node's "children" property (undefined when absent).
 */
function children(d) {
  return d["children"];
}

// Size of the <svg> element, in pixels.
var width = 1000;
var height = 1000;

// Lay the hierarchy out in a 400x400 area; the layout assigns x / y to every node.
var tree = d3.layout.tree().size([400, 400]).children(children);
var nodes = tree.nodes(data);
var links = tree.links(nodes);

// Root <g> shifted so that labels of the outermost nodes stay inside the canvas.
var svg = d3.select("#tree")
  .append("svg")
  .attr("width", width)
  .attr("height", height)
  .append("g")
  .attr("transform", "translate(200,30)");

// Path generator for parent -> child links.
var diagonal = d3.svg.diagonal()
  .projection(function(d) { return [d.x, d.y]; });

// Links are appended BEFORE the nodes: SVG paints in document order, so this
// keeps the link strokes underneath the circles and labels.
svg.selectAll(".link")
  .data(links)
  .enter()
  .append("path")
  .attr("class", "link")
  .attr("fill", "none")
  .attr("stroke", "blue")
  .attr("d", diagonal);

// One <g> per node, positioned at the coordinates computed by the layout.
var node = svg.selectAll(".node")
  .data(nodes)
  .enter()
  .append("g")
  .attr("class", "node")
  .attr("transform", function(d) {
    return "translate(" + d.x + "," + d.y + ")";
  });

// Internal nodes are filled black, leaves white.
// (A <circle> is positioned with cx/cy, so no "y" attribute is set here.)
node.append("circle")
  .attr("r", 10)
  .attr("stroke", "black")
  .attr("fill", function(d) {
    return d.children || d._children ? "black" : "white";
  });

// Label: above the circle for internal nodes, below it for leaves.
node.append("text")
  .text(function(d) { return d.name; })
  .style("font-size", "10px")
  .attr("x", -40)
  .attr("y", function(d) {
    return d.children || d._children ? -15 : 20;
  });
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="utf-8">
  <title>Decision Tree</title>
  <!-- Loaded over HTTPS so the script is not blocked as mixed content on HTTPS pages. -->
  <script src="https://d3js.org/d3.v3.min.js" charset="utf-8"></script>
  <script src="https://code.jquery.com/jquery-git2.min.js" charset="utf-8"></script>
  <link rel="stylesheet" href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.4/css/bootstrap.min.css">
  <style>
    .axis path,
    .axis line {
      fill: none;
      stroke: black;
      shape-rendering: crispEdges;
    }
    .axis text {
      font-family: sans-serif;
    }
  </style>
</head>
<body>
  <div class="container">
    <div class="jumbotron">
      <h1>Decision Tree</h1>
    </div>
    <!-- Mount point: draw-tree.js appends its <svg> here. -->
    <div id="tree"></div>
  </div>
  <script type="text/javascript" src="./js/draw-tree.js"></script>
</body>
</html>
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'use strict'; | |
var CSV = require("comma-separated-values"); | |
var fs = require("fs"); | |
var _ = require("underscore"); | |
var sys = require("sys"); | |
/**
 * Decision-tree node (constructor-style factory).
 *
 * Returns a plain object, so it behaves the same whether or not it is called
 * with `new`.
 *
 * @param {?number} col - Index of the column this node splits on (null for a leaf).
 * @param {*} val - Split value for that column (null for a leaf).
 * @param {Object} results - Class label -> row count for the rows that reached this node.
 * @param {?Object} tb - Child node for rows in the "true" set of the split.
 * @param {?Object} fb - Child node for rows in the "false" set of the split.
 * @param {boolean} isEnd - True when the node is a leaf.
 * @returns {Object} The node record.
 */
function DecisionNode(col, val, results, tb, fb, isEnd) {
  return {
    col: col,
    val: val,
    results: results,
    tb: tb,
    fb: fb,
    isEnd: isEnd
  };
}
/**
 * Counts how many rows belong to each class.
 *
 * The class label is taken from the LAST element of every row.
 *
 * @param {Array<Array>} rows - Data rows; may be empty.
 * @returns {Object} Map of class label -> number of rows (empty object for no rows).
 */
function uniqueCounts(rows) {
  var results = {};
  rows.forEach(function(row) {
    var label = row[row.length - 1];
    // Explicit own-property check so inherited Object.prototype members are
    // never mistaken for an existing count.
    if (Object.prototype.hasOwnProperty.call(results, label)) {
      results[label] += 1;
    } else {
      results[label] = 1;
    }
  });
  return results;
}
/**
 * Computes the Gini impurity of a set of rows: 1 - sum(p_i^2), where p_i is
 * the fraction of rows carrying class label i (as counted by uniqueCounts).
 *
 * @param {Array<Array>} rows - Data rows whose last element is the class label.
 * @returns {number} 0 when every row has the same class; 1 for an empty set
 *   (there are no classes, so the sum of squares is 0).
 */
function calcGini(rows) {
  var total = rows.length;
  var counts = uniqueCounts(rows);
  var sumOfSquares = Object.keys(counts).reduce(function(acc, label) {
    var proportion = counts[label] / total;
    return acc + proportion * proportion;
  }, 0);
  return 1 - sumOfSquares;
}
/**
 * Splits rows into two sets on one column and one value.
 *
 * - Numeric `value`: a row goes to the true set when row[column] > value.
 * - Any other `value`: a row goes to the true set when row[column] === value.
 * Every remaining row goes to the false set; relative row order is preserved.
 *
 * @param {Array<Array>} rows - Data rows.
 * @param {number} column - Index of the column to test.
 * @param {*} value - Threshold (number) or value to match exactly.
 * @returns {Array<Array<Array>>} [trueSet, falseSet]
 */
function divide(rows, column, value) {
  var isNumeric = typeof value === 'number';
  var trueSet = [];
  var falseSet = [];
  rows.forEach(function(row) {
    var matches = isNumeric ? row[column] > value : row[column] === value;
    if (matches) {
      trueSet.push(row);
    } else {
      falseSet.push(row);
    }
  });
  return [trueSet, falseSet];
}
/**
 * Lists the candidate split thresholds for one column: the midpoints between
 * consecutive distinct numeric values, formatted with two decimals.
 *
 * @param {Array<Array>} rows - Data rows.
 * @param {number} col - Index of the feature column.
 * @returns {Array<string>} Thresholds as produced by toFixed(2).
 */
function candidateThresholds(rows, col) {
  var values = rows
    .map(function(row) { return parseFloat(row[col]); })
    // Non-numeric cells cannot yield a usable threshold.
    .filter(function(value) { return !isNaN(value); })
    .sort(function(a, b) { return a - b; })
    // The array is sorted, so duplicates are adjacent.
    .filter(function(value, i, sorted) { return i === 0 || value !== sorted[i - 1]; });

  var thresholds = [];
  for (var i = 0; i < values.length - 1; i++) {
    thresholds.push(((values[i] + values[i + 1]) / 2).toFixed(2));
  }
  return thresholds;
}

/**
 * Recursively builds the decision tree.
 *
 * Every column except the last (the class label) is treated as a numeric
 * feature. For each candidate threshold the rows are split with divide() and
 * scored by the plain average of the two subsets' Gini impurities; the lowest
 * score wins.
 * NOTE(review): the score is an unweighted average, not the size-weighted one
 * used by textbook CART. Left unchanged here because switching would alter
 * the trees this script produces.
 *
 * A node becomes a leaf when the best split does not score below the node's
 * own impurity, or when either side of the best split has fewer than 2 rows.
 *
 * @param {Array<Array>} rows - Data rows; last element of each row is the class label.
 * @returns {Object} Root DecisionNode. For internal nodes `val` is the
 *   threshold string (two decimals), `tb` holds rows with value > threshold
 *   and `fb` the rest.
 */
function buildTree(rows) {
  var subResult = uniqueCounts(rows);

  // Nothing to split: return an empty leaf instead of failing on rows[0].
  if (rows.length === 0) {
    return new DecisionNode(null, null, subResult, null, null, true);
  }

  var currentGini = calcGini(rows);
  var bestCol;
  var bestVal;
  var bestGini = 1;
  var bestSets = [];
  var cols = rows[0].length - 1;

  for (var col = 0; col < cols; col++) {
    candidateThresholds(rows, col).forEach(function(threshold) {
      var dividedSets = divide(rows, col, parseFloat(threshold));
      // A split that leaves one side empty separates nothing.
      if (dividedSets[0].length === 0 || dividedSets[1].length === 0) {
        return;
      }
      var gini = (calcGini(dividedSets[0]) + calcGini(dividedSets[1])) / 2;
      if (gini < bestGini) {
        bestCol = col;
        bestVal = threshold;
        bestGini = gini;
        bestSets[0] = dividedSets[0];
        bestSets[1] = dividedSets[1];
      }
    });
  }

  // bestGini is still 1 when no valid split exists; impurity of a non-empty
  // set is always below 1, so the first test also guards the bestSets access.
  if (bestGini < currentGini && bestSets[0].length > 1 && bestSets[1].length > 1) {
    var trueBranch = buildTree(bestSets[0]);
    var falseBranch = buildTree(bestSets[1]);
    return new DecisionNode(bestCol, bestVal, subResult, trueBranch, falseBranch, false);
  }
  return new DecisionNode(null, null, subResult, null, null, true);
}
/**
 * Prints the generated tree to the console.
 *
 * A leaf prints its `results` object. An internal node prints "col:val" and
 * then its true ("T->") and false ("F->") branches, each prefixed by the
 * current indent and printed with one extra space of indent.
 *
 * Branch prefixes are written with process.stdout.write (no trailing
 * newline) so that the child's first line continues on the same line; the
 * legacy `sys.print` helper is not available in current Node.js releases.
 *
 * @param {Object} tree - Node as produced by buildTree.
 * @param {string} [indent=''] - Prefix for this node's branch lines.
 */
function printTree(tree, indent) {
  // `== null` covers both null and an omitted argument.
  var prefix = indent == null ? '' : indent;

  if (tree.isEnd) {
    console.log(tree.results);
    return;
  }

  console.log(tree.col + ":" + tree.val);
  process.stdout.write(prefix + 'T->');
  printTree(tree.tb, prefix + ' ');
  process.stdout.write(prefix + 'F->');
  printTree(tree.fb, prefix + ' ');
}
/**
 * Converts a decision tree into the JSON structure used for drawing.
 *
 * - Leaf: one `{name: "<class>: <count>"}` object per class in `results`.
 *   A single class is returned as that object; otherwise an array is returned
 *   (empty when `results` has no keys).
 * - Internal node: `{name: "<attribute> <= <val>", children: [...]}` where the
 *   children are the false branch's entries followed by the true branch's.
 *   Leaf arrays are flattened into `children`, so a multi-class leaf shows up
 *   as several sibling entries.
 *
 * @param {Object} tree - Node as produced by buildTree.
 * @param {Array<string>} attributes - Column names, indexed by `tree.col`.
 * @returns {Object|Array<Object>} Drawable node, or an array of leaf entries.
 */
function createJSON(tree, attributes) {
  if (tree.isEnd) {
    var leaves = Object.keys(tree.results).map(function(key) {
      return { name: key + ": " + tree.results[key] };
    });
    return leaves.length === 1 ? leaves[0] : leaves;
  }

  // concat() appends a plain object as-is and spreads an array one level,
  // which is exactly the flattening needed for multi-class leaves.
  return {
    name: attributes[tree.col] + " <= " + tree.val,
    children: [].concat(
      createJSON(tree.fb, attributes),
      createJSON(tree.tb, attributes)
    )
  };
}
/*
 * ********
 * * main *
 * ********
 *
 * Reads the CSV data set, builds the decision tree, writes the drawable JSON
 * to disk and prints the tree to the console.
 * Both paths are relative and therefore resolved against the current working
 * directory of the process.
 */
(function() {
  var csv = new CSV(fs.readFileSync('../source/iris.data', 'utf-8')).parse();
  // First row holds the column names; the remaining rows are the samples.
  var attributes = csv[0];
  var data = csv.slice(1);
  var tree = buildTree(data);
  var resultJSON = JSON.stringify(createJSON(tree, attributes), null, '\t');
  // Asynchronous write: a failure is reported but does not stop the printout below.
  fs.writeFile('../source/result.json', resultJSON, function(err) {
    if (err) console.error(err);
  });
  printTree(tree, '');
})();
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment