@mimno
Last active August 29, 2015 14:01

Here are some points, which happen to have been sampled from six round Gaussian distributions. Can we figure out where the centers of those Gaussians were, and which points came from which cluster? Theoretically, this is a hard problem: in general, no efficient method is known that guarantees the best clustering, short of checking all possible assignments of points to clusters.
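
To get a feel for why checking every assignment is out of the question, here is a small sketch that counts the labelings for this demo's 60 points and 6 clusters. `countAssignments` is a made-up helper name, and the count treats relabelings of the same grouping as different, so read it as an upper bound rather than an exact number of distinct clusterings.

```javascript
// Count the ways to give each of numPoints points one of numClusters labels.
// Assumption: relabelings of the same grouping are counted separately,
// so this is an upper bound on the number of distinct clusterings.
function countAssignments(numPoints, numClusters) {
  // BigInt avoids the precision loss of ordinary floating-point numbers.
  return BigInt(numClusters) ** BigInt(numPoints);
}

const total = countAssignments(60, 6);
console.log(total.toString().length + " digits"); // prints "47 digits"
```

Even at a billion assignments per second, enumerating a 47-digit number of candidates is not feasible, which is why an iterative shortcut is attractive.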

We use an iterative algorithm, k-means. Rather than solving the hard problem of finding the best cluster assignments, this algorithm alternates between two easy problems: finding the cluster center closest to each point, and calculating an average.

Click the "Random Clusters" button to drop in six random cluster centers. The light gray lines represent the "continental divide" between clusters: any point in the region associated with a cluster is closer to that cluster than any other cluster.
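
The "continental divide" idea can be approximated without any geometry library by labeling each cell of a coarse grid with its nearest cluster center. This is only a rough sketch and not how the demo draws its boundaries (the demo uses d3's Voronoi layout); `regionMap` and its `size` and `step` parameters are names invented for this example.

```javascript
// Label every cell of a coarse grid with the index of its nearest center.
// Hypothetical helper: the grid covers [0, size] in x and y, sampled at cell midpoints.
function regionMap(centers, size, step) {
  const rows = [];
  // Start from the top row so the printed map has the same orientation as the plot.
  for (let y = size - step / 2; y > 0; y -= step) {
    let row = "";
    for (let x = step / 2; x < size; x += step) {
      let best = 0;
      let bestDist = Infinity;
      centers.forEach((c, i) => {
        const d = (x - c.x) ** 2 + (y - c.y) ** 2; // squared distance is enough for comparison
        if (d < bestDist) { bestDist = d; best = i; }
      });
      row += best;
    }
    rows.push(row);
  }
  return rows;
}

const map = regionMap([{ x: 10, y: 10 }, { x: 40, y: 40 }, { x: 40, y: 10 }], 50, 5);
console.log(map.join("\n"));
```

The places where neighboring digits differ trace out the same boundaries the gray lines show, only at a coarser resolution.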

Click the "Cluster Points" button to assign each point to its nearest cluster. You should see cluster numbers appear. Each time a point's assignment changes, its number will be highlighted briefly.
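
Stripped of the drawing code, the assignment step might look like the sketch below. `nearestCenter` is a made-up helper, and the sample coordinates are invented for illustration.

```javascript
// Return the index of the center nearest to a point (-1 if there are no centers).
function nearestCenter(point, centers) {
  let best = -1;
  let bestDist = Infinity;
  centers.forEach((c, i) => {
    const dx = point.x - c.x;
    const dy = point.y - c.y;
    const d = dx * dx + dy * dy; // squared distance preserves the ordering
    if (d < bestDist) { bestDist = d; best = i; }
  });
  return best;
}

// Invented sample data: two centers and three points.
const centers = [{ x: 10, y: 10 }, { x: 40, y: 40 }];
const pts = [{ x: 12, y: 9 }, { x: 38, y: 41 }, { x: 20, y: 20 }];
const labels = pts.map(p => nearestCenter(p, centers));
console.log(labels); // [ 0, 1, 0 ]
```

Comparing squared distances skips the square root, which does not change which center comes out nearest.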

Click the "Move Clusters" button to shift the clusters so that they are at the centroid of the points they are responsible for. This just averages the x and y coordinates of those points.
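
As a minimal sketch of that averaging, assuming points are plain `{x, y}` objects: `centroid` is a hypothetical helper, and returning `null` for an empty group is a choice made here for the example (the demo instead moves empty clusters to a random spot).

```javascript
// Average the x and y coordinates of a group of points.
// Assumption: an empty group has no centroid, so null is returned.
function centroid(pts) {
  if (pts.length === 0) return null;
  let sumX = 0;
  let sumY = 0;
  for (const p of pts) {
    sumX += p.x;
    sumY += p.y;
  }
  return { x: sumX / pts.length, y: sumY / pts.length };
}

const c = centroid([{ x: 2, y: 4 }, { x: 4, y: 8 }, { x: 6, y: 0 }]);
console.log(c); // { x: 4, y: 4 }
```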

Now alternate between clicking those two buttons. If a cluster loses all its points, I move it to a random location. Eventually, you should reach a state where no points want to change clusters, so the clusters don't move. The algorithm has converged. This may not be the best overall cluster assignment, but it is a local optimum: neither step can improve it any further.
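
Putting the two steps in a loop gives something like the following sketch, which stops as soon as an assignment pass changes nothing. The function name `kMeans`, the `randomCenter` callback for re-seeding empty clusters, and the `maxIterations` cap are all assumptions made for this example; the demo itself leaves the looping to your button clicks.

```javascript
// Run k-means from the given starting centers until no label changes.
// randomCenter is a hypothetical callback used to re-seed clusters that lose all points.
function kMeans(points, initialCenters, randomCenter, maxIterations = 100) {
  let centers = initialCenters.map(c => ({ x: c.x, y: c.y }));
  const labels = points.map(() => -1);

  for (let iter = 1; iter <= maxIterations; iter++) {
    // Assignment step: give each point the index of its nearest center.
    let changed = false;
    points.forEach((p, i) => {
      let best = -1;
      let bestDist = Infinity;
      centers.forEach((c, j) => {
        const d = (p.x - c.x) ** 2 + (p.y - c.y) ** 2;
        if (d < bestDist) { bestDist = d; best = j; }
      });
      if (labels[i] !== best) { labels[i] = best; changed = true; }
    });
    if (!changed) return { centers, labels, iterations: iter };

    // Update step: move each center to the average of its points.
    centers = centers.map((c, j) => {
      const members = points.filter((_, i) => labels[i] === j);
      if (members.length === 0) return randomCenter();
      return {
        x: members.reduce((sum, p) => sum + p.x, 0) / members.length,
        y: members.reduce((sum, p) => sum + p.y, 0) / members.length
      };
    });
  }
  return { centers, labels, iterations: maxIterations };
}

// Invented sample data: two obvious groups and two starting centers.
const points = [{ x: 1, y: 1 }, { x: 2, y: 1 }, { x: 9, y: 9 }, { x: 10, y: 9 }];
const result = kMeans(points, [{ x: 0, y: 0 }, { x: 5, y: 5 }],
  () => ({ x: 10 * Math.random(), y: 10 * Math.random() }));
console.log(result.labels, result.centers);
```

Different starting centers can settle into different final assignments, which is the local-optimum caveat in action.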

<html>
<head>
<script src="https://d3js.org/d3.v3.min.js"></script>
<style>
body { font-family: "Open Sans", Calibri, Verdana; }
</style>
</head>
<body>
<div id="header"><div>An iterative algorithm: k-means</div>
<!-- Some interactive controls -->
<button id="addClusters">Random Clusters</button>
<button id="clusterPoints">Cluster Points</button>
<button id="moveMeans">Move Clusters</button>
</div>
<div id="chart"></div>
<script>
var height = 500;
var width = 600;
var svg = d3.select("#chart").append("svg").attr("height", height).attr("width", width).append("g");
var xScale = d3.scale.linear().domain([0, 50]).range([0, width]);
var yScale = d3.scale.linear().domain([0, 50]).range([height, 0]);
// A voronizer function
var voronoi = d3.geom.voronoi()
.x(function (cluster) { return xScale(cluster.x); })
.y(function (cluster) { return yScale(cluster.y); })
.clipExtent([[0,0], [width, height]]);
// A helper function for plotting Voronoi polygons
function polygonPath(d) {
return "M" + d.join("L") + "Z";
}
// The data array for cluster centers.
// We'll make six clusters, at random locations, when "Random Clusters" is clicked.
var clusters = [];
var numClusters = 6;
// Each cluster is represented by a <text> and a <path> element.
var clusterTexts;
var clusterBorders;
// The data array for points to be clustered.
var points = [];
var pointTexts;
var randomNormal = d3.random.normal(0, 2);
for (var cluster = 0; cluster < numClusters; cluster++) {
var clusterCenter = { x: 40 * Math.random() + 5, y: 40 * Math.random() + 5 };
for (var point = 0; point < 10; point++) {
points.push({ id: points.length, x: clusterCenter.x + randomNormal(), y: clusterCenter.y + randomNormal(), cluster: "" });
}
}
// Add circles for each point
svg.selectAll("circle").data(points).enter().append("circle")
.style("fill", "#ccc").style("opacity", 0.7)
.attr("cx", function (p) { return xScale(p.x); })
.attr("cy", function (p) { return yScale(p.y); })
.attr("r", 4);
// Create <text> elements for each point, showing
// their current assignment.
pointTexts = svg.selectAll(".point").data(points);
pointTexts.enter().append("text").attr("class", "point");
pointTexts
.attr("id", function (point) { return "p" + point.id; })
.attr("x", function (point) { return xScale(point.x); } )
.attr("y", function (point) { return yScale(point.y); } )
.style("font-size", "small")
.text(function (point) { return point.cluster; });
// Now we define three functions, which correspond
// to initialization and the two steps of the iterative algorithm:
// Initialization: place clusters randomly
d3.select("#addClusters").on("click", function() {
clusters = [];
for (var i = 0; i < numClusters; i++) {
clusters.push( { x: 50 * Math.random(), y: 50 * Math.random(), n: 0 } );
}
// Create <text> elements for each cluster.
var texts = svg.selectAll(".cluster").data(clusters);
texts.enter().append("text").attr("class", "cluster");
texts
.attr("x", function (cluster) { return xScale(cluster.x); } )
.attr("y", function (cluster) { return yScale(cluster.y); } )
.text(function (cluster, i) { return i; } )
.style("fill", "red")
.style("font-size", "xx-large");
clusterBorders = svg.selectAll(".clusterborder").data(voronoi(clusters));
clusterBorders.enter().append("path").attr("class", "clusterborder")
.style("fill", "none").style("stroke", "#cccccc");
clusterBorders.attr("d", polygonPath);
clusterTexts = texts;
});
// Iterative step 1: Move the cluster centers based on the current
// assignment of points to clusters.
d3.select("#moveMeans").on("click", function() {
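// Do nothing until clusters exist and every point has a valid assignment;
// otherwise clusters[ point.cluster ] below would be undefined.
var ready = clusterTexts && points.every( function (point) {
return clusters[ point.cluster ] !== undefined;
} );
if (!ready) { return; }
// Reset the clusters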
clusters.forEach( function (cluster) {
cluster.x = 0; cluster.y = 0; cluster.n = 0;
} );
// Sum up x and y coordinates
points.forEach( function (point) {
clusters[ point.cluster ].x += point.x;
clusters[ point.cluster ].y += point.y;
clusters[ point.cluster ].n ++;
});
// Divide to get the average
clusters.forEach( function (cluster) {
// Move empty clusters to random locations
if (cluster.n === 0) {
cluster.x = 40 * Math.random() + 5;
cluster.y = 40 * Math.random() + 5;
}
else {
cluster.x /= cluster.n;
cluster.y /= cluster.n;
}
});
// Now change the underlying data and move the
// <text> elements.
clusterTexts.data(clusters);
clusterTexts.transition()
.attr("x", function (cluster) { return xScale(cluster.x); } )
.attr("y", function (cluster) { return yScale(cluster.y); } );
clusterBorders.data(voronoi(clusters));
clusterBorders.transition()
.attr("d", polygonPath);
});
// Iterative step 2: Reassign points to their nearest cluster center.
d3.select("#clusterPoints").on("click", function () {
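// Without cluster centers there is nothing to assign points to.
if (clusters.length === 0) { return; }
// Consider every point...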
points.forEach( function (point, i) {
var shortestDistance = Number.POSITIVE_INFINITY;
// Find the nearest cluster...
var newCluster = -1;
clusters.forEach( function (cluster, id) {
var distance = Math.sqrt(Math.pow(point.x - cluster.x, 2) + Math.pow(point.y - cluster.y, 2));
if (distance < shortestDistance) {
shortestDistance = distance;
newCluster = id;
}
} );
// Strict comparison: the initial "" would otherwise compare equal to cluster 0.
if (point.cluster !== newCluster) {
d3.select("#p" + point.id).transition().duration(1000)
.style("font-size", "x-large");
}
point.cluster = newCluster;
} );
// Update the text of the points to reflect the
// new cluster assignments.
pointTexts.data(points);
pointTexts.transition().delay(1000).duration(1000)
.text(function (point) { return point.cluster; });
pointTexts.transition().delay(2000).duration(1000)
.style("font-size", "small");
});
</script>
</body>
</html>