Skip to content

Instantly share code, notes, and snippets.

Created July 14, 2014 05:59
Show Gist options
  • Save hardbyte/ded34566f6fb704264b4 to your computer and use it in GitHub Desktop.
Save hardbyte/ded34566f6fb704264b4 to your computer and use it in GitHub Desktop.
K-means with D3js
<!DOCTYPE html>
<html lang="en">
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<!-- Latest compiled and minified CSS -->
<link rel="stylesheet" href="//">
<!-- Optional theme -->
<link rel="stylesheet" href="//">
<style>
path {
stroke: #a51314;
fill: none;
}
circle.data {
fill: steelblue;
pointer-events: none;
}
circle.means {
fill: red;
opacity: 0.3;
}
line.ax {
stroke: grey;
stroke-width: 2px;
}
</style>
<div class="container-fluid">
<div class="row">
<div id="vis"></div>
<script src=""></script>
<script src="//"></script>
<script src=""></script>
<script src=""></script>
// Chart geometry: the outer SVG is 960x500 pixels; `width` and `height` are
// the size of the inner plotting area once the margins are subtracted.
var margin = {top: 20, right: 20, bottom: 30, left: 50},
    width = 960 - margin.left - margin.right,
    height = 500 - margin.top - margin.bottom;

// Scales from data coordinates to pixels. No domain is set here, so the d3
// default applies. The Y range is reversed so larger values map to smaller
// pixel offsets (towards the top of the plot).
var X = d3.scale.linear()
    .range([0, width]);
var Y = d3.scale.linear()
    .range([height, 0]);

// Voronoi layout clipped to the plotting area; used to outline the region
// belonging to each mean.
var voronoi = d3.geom.voronoi()
    .clipExtent([[0, 0], [width, height]]);

// NOTE(review): `d3.select(`, `margin.top` and the `<g>` wrapper were missing
// from the extracted listing and are restored here following the usual d3
// margin convention -- confirm against the original gist.
var svg = d3.select("#vis").append("svg")
    .attr("width", width + margin.left + margin.right)
    .attr("height", height + margin.top + margin.bottom)
  .append("g")
    .attr("transform", "translate(" + margin.left + "," + margin.top + ")");
// Number of clusters: how many means are fitted and how many synthetic
// clusters are generated.
var k = 4;
// How many additional samples are revealed on each animation frame.
var numSamplesPerFrame = 10;
// How many samples are generated for each synthetic cluster.
var numSamplesPerCluster = 200;

// Declared but not referenced anywhere in the visible source.
var data;

// Parallel arrays describing the samples: x coordinate, y coordinate and the
// index of the mean each sample is currently assigned to.
var xdata = [];
var ydata = [];
var cdata = [];

// Parallel arrays holding the current position of each of the k means.
var x_means = [];
var y_means = [];
// Draw the normal axes: one horizontal segment through y = 0.5 and one
// vertical segment through x = 0.5, both given in data coordinates.
// Each datum is [startPoint, endPoint, label]; the label is the join key.
// NOTE(review): the "line.ax" selector, the `.data([` call and the
// enter/append step were missing from the extracted listing and are
// reconstructed from the 'ax' class and x1/y1/x2/y2 attributes set below.
var components = svg.selectAll("line.ax")
    .data([
        [[0.0, 0.5], [1, 0.5], "X"],
        [[0.5, 0.0], [0.5, 1], "Y"]
    ], function (d) { return d[2]; });

components.enter().append("line")
    .attr('class', 'ax')
    .attr('x1', function (d) { return X(d[0][0]); })
    .attr('y1', function (d) { return Y(d[0][1]); })
    .attr('x2', function (d) { return X(d[1][0]); })
    .attr('y2', function (d) { return Y(d[1][1]); });
/**
 * Draw a roughly bell-shaped random number centred on `mean`.
 *
 * Sums ten uniform draws from [-1, 1) and scales the sum by `std`, so the
 * result always lies within mean +/- 10 * |std|. Note that the unscaled sum
 * has a standard deviation of sqrt(10 / 3) (about 1.83), so `std` is a
 * spread factor rather than the exact standard deviation of the output.
 *
 * @param {number} mean - Centre of the distribution.
 * @param {number} std - Scale factor applied to the summed uniform draws.
 * @returns {number} The random sample.
 */
function rnd(mean, std){
    var r = 0;
    for (var i = 0; i < 10; i++) {
        r += Math.random() * 2 - 1;
    }
    return r * std + mean;
}
/**
 * Initialise the k-means run by placing each of the `k` means at a uniformly
 * random position in the unit square. Writes into the shared `x_means` and
 * `y_means` arrays; returns nothing.
 */
function kmeans(){
    // Step 1, choose k random starting positions
    // `i` is declared locally; the original loop leaked it as an implicit global.
    for (var i = 0; i < k; i++) {
        x_means[i] = Math.random();
        y_means[i] = Math.random();
    }
}
// Number of times step() has run so far.
var numSteps = 0;

/**
 * Run one k-means iteration and refresh the drawing.
 *
 * Assigns every sample to its nearest mean (squared Euclidean distance,
 * stored in `cdata`), then moves each mean to the centroid of the samples
 * assigned to it. A mean with no samples is re-seeded at a random position.
 *
 * @returns {boolean} true once at least 100 steps have run and this step
 *     moved no mean; false otherwise. The value is returned from the
 *     d3.timer callback that drives the animation.
 */
function step() {
    var path = svg.selectAll("path");

    // Redraw the Voronoi cells of the current means (in pixel coordinates).
    function redraw() {
        var d = [];
        for (var i = 0; i < k; i++) {
            d.push([X(x_means[i]), Y(y_means[i])]);
        }
        var vd = voronoi(d);
        // The path string itself is the join key, so a cell whose outline
        // changed leaves via exit() and its new outline arrives via enter().
        var v = path
            .data(vd, polygon);
        // NOTE(review): the exit/enter lines were missing from the extracted
        // listing; reconstructed after the standard d3 v3 Voronoi pattern.
        v.exit().remove();
        v.enter().append("path")
            .attr("d", polygon);
        v.order();
    }

    // Convert a list of [x, y] vertices into a closed SVG path string.
    function polygon(d) {
        return "M" + d.join("L") + "Z";
    }

    // NOTE(review): no call to plotMeans() or redraw() survives in the
    // extracted listing; drawing the current state before updating it is an
    // assumption -- confirm against the original gist.
    plotMeans();
    redraw();

    // For each point calculate the nearest mean
    // TODO partitioning the observations according to the Voronoi diagram generated by the means
    for (var i = 0; i < xdata.length; i++) {
        var nearestDistance = Infinity;
        for (var j = 0; j < k; j++) {
            var distance = Math.pow(xdata[i] - x_means[j], 2) + Math.pow(ydata[i] - y_means[j], 2);
            if (distance < nearestDistance) {
                nearestDistance = distance;
                cdata[i] = j;
            }
        }
    }

    // For each mean calculate the centroid of all points
    var keepGoing = (++numSteps < 100);
    for (var j = 0; j < k; j++) {
        var n = 0;
        var totalx = 0, totaly = 0;
        for (var i = 0; i < xdata.length; i++) {
            if (cdata[i] === j) {
                n += 1;
                totalx += xdata[i];
                totaly += ydata[i];
            }
        }
        if (n === 0) {
            // Not part of any clusters: re-seed this mean at a random spot.
            n = 1;
            totalx = Math.random();
            totaly = Math.random();
        }
        // Any movement of a mean forces another iteration.
        if (totalx / n != x_means[j] || totaly / n != y_means[j]) {
            x_means[j] = totalx / n;
            y_means[j] = totaly / n;
            keepGoing = true;
        }
    }
    return !keepGoing;
}
/**
 * Draw the first `i` samples as small circles.
 *
 * @param {number} i - How many samples, counted from the start of
 *     `xdata`/`ydata`, should be visible.
 */
function plotData(i){
    // Pair up the coordinate arrays into [x, y] rows and keep the first i.
    var xycoords = numeric.transpose([xdata, ydata]).slice(0, i);
    // NOTE(review): the "circle.data" selector, the data join and the
    // enter/append step were missing from the extracted listing; they are
    // reconstructed from the 'data' class set below -- confirm.
    var circle = svg.selectAll("circle.data")
        .data(xycoords);
    // Newly revealed samples get a circle element.
    circle.enter().append("circle")
        .attr('class', 'data')
        .attr("r", 1);
    // Position every circle, new or existing.
    circle
        .attr("cx", function (d) { return X(d[0]); })
        .attr("cy", function (d) { return Y(d[1]); });
}
/**
 * Draw one large circle per mean at its current position, creating the
 * circle elements the first time and repositioning them afterwards.
 */
function plotMeans(){
    var circle = svg.selectAll("circle.means")
        .data(numeric.transpose([x_means, y_means]));
    // NOTE(review): the enter/append line and the receiver of the cx/cy
    // chain were missing from the extracted listing; reconstructed -- confirm.
    circle.enter().append("circle")
        .attr('class', 'means')
        .attr("r", 10);
    // Move every mean marker, new or existing, to its current coordinates.
    circle
        .attr("cx", function (d) { return X(d[0]); })
        .attr("cy", function (d) { return Y(d[1]); });
}
/**
 * Clamp `val` to the closed interval [min, max].
 *
 * @param {number} val - Value to clamp.
 * @param {number} min - Lower bound.
 * @param {number} max - Upper bound.
 * @returns {number} `min` if val < min, `max` if val > max, otherwise `val`
 *     unchanged (so NaN passes through, since both comparisons are false).
 */
function lim(val, min, max){
    if (val < min) {
        return min;
    }
    if (val > max) {
        return max;
    }
    return val;
}
/**
 * Generate the synthetic data set: `k` clusters of `numSamplesPerCluster`
 * samples each, appended to the shared `xdata` and `ydata` arrays.
 *
 * Each cluster draws two latent values `a` and `b` around cluster-specific
 * means and mixes them linearly into x and y; the results are clamped to
 * [0, 1].
 */
function createData() {
    for (var cluster = 0; cluster < k; ++cluster) {
        // Cluster centre for the two latent variables.
        var amean = rnd(0.5, 0.1);
        var bmean = rnd(0.5, 0.1);
        // Spread used for both latent variables.
        var astd = rnd(0.03, 0.01);
        // Mixing weights: x is mostly `a`, y is mostly `b`.
        var ax = rnd(0.98, 0.01);
        var ay = rnd(0.02, 0.01);
        var bx = 1 - ax;
        var by = 1 - ay;
        for (var i = 0; i < numSamplesPerCluster; ++i) {
            var a = rnd(amean, astd),
                b = rnd(bmean, astd);
            var x = lim(ax * a + bx * b, 0, 1),
                y = lim(ay * a + by * b, 0, 1);
            // NOTE(review): the lines storing the sample were missing from
            // the extracted listing. xdata/ydata are read by plotData() and
            // step() and filled nowhere else, so the pushes are restored.
            xdata.push(x);
            ydata.push(y);
        }
    }
}
// Animation driver. Phase 1 reveals the samples a few at a time; once all of
// them are visible, phase 2 runs k-means one step per timer tick.
// NOTE(review): no call to createData() or kmeans() survives in the extracted
// listing, yet nothing else fills the data arrays or seeds the means. The two
// calls below are reconstructed -- confirm their placement against the
// original gist.
createData();

var numFrames = 0;
d3.timer(function () {
    plotData(numSamplesPerFrame * numFrames);
    if ((++numFrames) * numSamplesPerFrame > k * numSamplesPerCluster) {
        // All samples are drawn: seed the means and start iterating.
        kmeans();
        d3.timer(function () {
            // step() returns true when clustering is done, which stops this timer.
            return step();
        });
        // Returning true stops the reveal timer.
        return true;
    }
});
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment