Skip to content

Instantly share code, notes, and snippets.

@schnerd
Created July 29, 2017 05:16
Show Gist options
  • Save schnerd/0041305b43b7b5e9c2b9dffae449a012 to your computer and use it in GitHub Desktop.
Save schnerd/0041305b43b7b5e9c2b9dffae449a012 to your computer and use it in GitHub Desktop.
Ckmeans Algorithm Visualization [Work in Progress]
license: mit

This block is a hastily put-together visualization of the Ckmeans algorithm. I used this algorithm to build d3-scale-cluster last year, but I never truly grasped how the underlying algorithm worked. Inspired by algorithm-visualization work from others in the community, I figured the best way to understand Ckmeans would be to visualize it—and thus, this block was born.

This is very much a work in progress, and the code needs to be cleaned up significantly—I just wanted to use this block for my unconf registration! I hope to have a more fleshed-out version in the coming weeks.

<!DOCTYPE html>
<head>
<meta charset="utf-8">
<script src="https://d3js.org/d3.v4.min.js"></script>
<script src="https://unpkg.com/lodash@4.17.4/lodash.min.js"></script>
<style>
body {
margin:0;position:fixed;top:0;right:0;bottom:0;left:0;
font-family: Helvetica, Arial, sans-serif;
}
.matrix-title {
font-size: 12px;
fill: #999;
}
</style>
</head>
<body>
<script>
// Cell geometry (px), default animation delay (ms) and the resting cell stroke.
const CELL_WIDTH = 30;
const CELL_HEIGHT = 12;
const SLEEP_MS = 20;
const CELL_STROKE_COLOR = 'rgba(0, 0, 0, 0.05)';
const nClusters = 4;

// Fill-colour scale for matrix cells; assigned inside ckmeans() before any cell is coloured.
let scale;

/**
 * Appends a <g> panel with the given id and horizontal offset to the root svg,
 * adds its grey title label, and returns the panel selection.
 */
function appendPanel(id, offsetX, title) {
  const $panel = svg.append('g');
  $panel
    .attr('id', id)
    .attr('transform', `translate(${offsetX}, 50)`);
  $panel
    .append('text')
    .classed('matrix-title', true)
    .attr('y', -16)
    .text(title);
  return $panel;
}

// Feel free to change or delete any of the code you see in this editor!
var svg = d3.select('body').append('svg')
  .attr("width", 960)
  .attr("height", 500);

// Panels are appended in this order: input vector, sum-of-squares matrix, backtrack matrix.
const $vector = appendPanel('vector', 20, 'Input');
const $matrix = appendPanel('matrix', 280, 'Sum-of-squares Matrix');
const $backtrackMatrix = appendPanel('backtrackMatrix', 80, 'Backtrack Matrix');
/**
 * Returns a promise that resolves (with no value) after a delay.
 * @param {number} [ms] delay in milliseconds; any falsy value falls back to SLEEP_MS
 * @returns {Promise<void>}
 */
function sleep(ms) {
  const delay = ms || SLEEP_MS;
  return new Promise((done) => {
    setTimeout(done, delay);
  });
}
// Selection helper: re-appends each selected node to its own parent, which
// moves it to the end of the parent's child list. Returns the selection.
d3.selection.prototype.moveToFront = function() {
  function reattach() {
    this.parentNode.appendChild(this);
  }
  return this.each(reattach);
};
/**
 * Formats a number as a compact label for a matrix cell.
 *
 * - Negative values are formatted by magnitude and prefixed with "-".
 * - Values >= 1000 are scaled and suffixed with k / M / B / T.
 * - Values >= 1 are shown with up to 3 significant digits.
 * - Values in [0.0001, 1) are shown without the leading zero (".25").
 * - Smaller positive values use exponential notation; everything else is "0".
 *
 * @param {number} value number to format
 * @returns {string} short textual representation
 */
function formatValue(value) {
  if (value < 0) {
    return `-${formatValue(-value)}`;
  }
  let suffix = '';
  if (value >= 1e12) {
    suffix = 'T';
    value = value / 1e12;
  } else if (value >= 1e9) {
    suffix = 'B';
    value = value / 1e9;
  } else if (value >= 1e6) {
    suffix = 'M';
    value = value / 1e6;
  } else if (value >= 1e3) {
    suffix = 'k';
    value = value / 1e3;
  }
  /**
   * In the below logic, casting to number and back to string is
   * used to trim trailing zeros
   */
  if (value >= 1) {
    return Number(value.toPrecision(3)).toString() + suffix;
  }
  if (value >= 0.0001) {
    // Three decimals down to 0.001, four significant digits below that.
    const digits = value >= 0.001 ? value.toFixed(3) : value.toPrecision(4);
    const rounded = Number(digits);
    // Rounding can carry a value such as 0.9999 up to exactly 1; there is then
    // no leading "0" to strip, and stripping would leave an empty label.
    if (rounded >= 1) {
      return rounded.toString();
    }
    // Numbers less than 1 - drop the leading 0
    return rounded.toString().slice(1);
  }
  if (value > 0) {
    return value.toExponential(1);
  }
  return '0';
}
/**
 * Returns a new array holding the input's elements in ascending numeric order.
 * The input array itself is left untouched.
 * @param {Array<number>} array values to sort
 * @returns {Array<number>} sorted copy
 */
function numericSort (array) {
  const byAscendingValue = function (left, right) {
    return left - right;
  };
  // Work on a copy so the caller's array is not reordered.
  const copy = array.slice();
  copy.sort(byAscendingValue);
  return copy;
}
/**
 * Counts the runs of equal adjacent values in the input, which for an
 * already-sorted array is the number of distinct values.
 * @param {Array<number>} input sorted values
 * @returns {number} number of runs (0 for an empty input)
 */
function uniqueCountSorted (input) {
  let distinct = 0;
  for (let idx = 0; idx < input.length; idx += 1) {
    // A new run starts at the first element and wherever the value changes.
    const startsNewRun = idx === 0 || input[idx] !== input[idx - 1];
    if (startsNewRun) {
      distinct += 1;
    }
  }
  return distinct;
}
/**
 * Builds a zero-filled matrix stored as an array of columns, so a cell is
 * addressed as matrix[column][row].
 * @param {number} columns number of columns
 * @param {number} rows number of rows in each column
 * @returns {Array<Array<number>>} the new matrix
 */
function makeMatrix (columns, rows) {
  const grid = [];
  while (grid.length < columns) {
    const cells = [];
    while (cells.length < rows) {
      cells.push(0);
    }
    grid.push(cells);
  }
  return grid;
}
/**
 * Sum of squared deviations from the mean for the values at indices j..i,
 * computed from prefix sums. Negative results are clamped to 0.
 * @param {number} j first index of the range
 * @param {number} i last index of the range
 * @param {Array<number>} sumX prefix sums of the values
 * @param {Array<number>} sumXsq prefix sums of the squared values
 * @returns {number} s(j, i), never negative
 */
function ssq (j, i, sumX, sumXsq) {
  let cost; // s(j, i)
  if (j > 0) {
    const count = i - j + 1;
    const mean = (sumX[i] - sumX[j - 1]) / count; // mu(j, i)
    cost = sumXsq[i] - sumXsq[j - 1] - count * mean * mean;
  } else {
    // The range starts at index 0, so the prefix sums are the range sums.
    cost = sumXsq[i] - sumX[i] * sumX[i] / (i + 1);
  }
  if (cost < 0) {
    return 0;
  }
  return cost;
}
/**
 * Fills rows imin..imax of one column of the sum-of-squares matrix and of the
 * backtrack matrix, redrawing every cell it writes.
 *
 * The row at the midpoint of the range is solved first; the two halves on
 * either side are then filled recursively. Backtrack entries that are already
 * known for neighbouring rows narrow the range [jlow, jhigh] of candidate
 * split indices that has to be scanned for the midpoint row.
 *
 * @param {number} imin first row to fill
 * @param {number} imax last row to fill
 * @param {number} column column being filled (callers pass column >= 1)
 * @param {Array<Array<number>>} matrix sum-of-squares matrix, matrix[column][row]
 * @param {Array<Array<number>>} backtrackMatrix split indices, same layout
 * @param {Array<number>} sumX prefix sums of the shifted values
 * @param {Array<number>} sumXsq prefix sums of the squared shifted values
 * @returns {Promise<void>} resolves when the whole range has been filled and drawn
 */
async function fillMatrixColumn (imin, imax, column, matrix, backtrackMatrix, sumX, sumXsq) {
  if (imin > imax) {
    return;
  }
  // Start at midpoint between imin and imax
  const i = Math.floor((imin + imax) / 2);
  // Initialization of S[k][i]:
  matrix[column][i] = matrix[column - 1][i - 1];
  // TODO better animation of above?
  await drawMatrixCell($matrix, matrix, column, i, { flash: true, color: true });
  await sleep();
  backtrackMatrix[column][i] = i;
  await drawMatrixCell($backtrackMatrix, backtrackMatrix, column, i, { flash: true, color: true });
  await sleep();
  let jlow = column; // the lower end for j
  if (imin > column) {
    jlow = Math.max(jlow, backtrackMatrix[column][imin - 1] || 0);
  }
  jlow = Math.max(jlow, backtrackMatrix[column - 1][i] || 0);
  let jhigh = i - 1; // the upper end for j
  // Row imax + 1 exists whenever imax is not the last row. Rows are indexed by
  // value, so the bound is the number of values (matrix[0].length), not the
  // number of cluster columns (matrix.length).
  if (imax < matrix[0].length - 1) {
    jhigh = Math.min(jhigh, backtrackMatrix[column][imax + 1] || 0);
  }
  let sji;
  let sjlowi;
  let ssqjlow;
  let ssqj;
  for (let j = jhigh; j >= jlow; --j) {
    // compute s(j,i)
    sji = ssq(j, i, sumX, sumXsq);
    // MS May 11, 2016 Added:
    if (sji + matrix[column - 1][jlow - 1] >= matrix[column][i]) {
      break;
    }
    // Examine the lower bound of the cluster border
    // compute s(jlow, i)
    sjlowi = ssq(jlow, i, sumX, sumXsq);
    ssqjlow = sjlowi + matrix[column - 1][jlow - 1];
    if (ssqjlow < matrix[column][i]) {
      // shrink the lower bound
      matrix[column][i] = ssqjlow;
      backtrackMatrix[column][i] = jlow;
      await drawMatrixCell($matrix, matrix, column, i, { flash: true, color: true });
      await sleep();
      await drawMatrixCell($backtrackMatrix, backtrackMatrix, column, i, { flash: true, color: true });
      await sleep();
    }
    jlow++;
    ssqj = sji + matrix[column - 1][j - 1];
    if (ssqj < matrix[column][i]) {
      matrix[column][i] = ssqj;
      backtrackMatrix[column][i] = j;
      await drawMatrixCell($matrix, matrix, column, i, { flash: true, color: true });
      await sleep();
      await drawMatrixCell($backtrackMatrix, backtrackMatrix, column, i, { flash: true, color: true });
      await sleep();
    }
  }
  // Fill the rows on either side of the midpoint.
  await fillMatrixColumn(imin, i - 1, column, matrix, backtrackMatrix, sumX, sumXsq);
  await fillMatrixColumn(i + 1, imax, column, matrix, backtrackMatrix, sumX, sumXsq);
}
/**
 * Fills both matrices in place while animating each update.
 *
 * Column 0 is computed directly from prefix sums of the shifted data; every
 * later column is delegated to fillMatrixColumn.
 *
 * @param {Array<number>} data input values (ckmeans passes the sorted array)
 * @param {Array<Array<number>>} matrix sum-of-squares matrix, matrix[column][row]
 * @param {Array<Array<number>>} backtrackMatrix backtrack matrix, same layout
 * @returns {Promise<void>} resolves once every column has been filled
 */
async function fillMatrices (data, matrix, backtrackMatrix) {
  const nValues = matrix[0].length;
  const sumX = new Array(nValues);
  const sumXsq = new Array(nValues);
  // Shift every value by the middle element to improve numerical stability
  const shift = data[Math.floor(nValues / 2)];
  // Build the prefix sums and column 0 together, one row at a time
  for (let i = 0; i < nValues; ++i) {
    const deviation = data[i] - shift;
    if (i === 0) {
      sumX[0] = deviation;
      sumXsq[0] = deviation * deviation;
    } else {
      sumX[i] = sumX[i - 1] + data[i] - shift;
      sumXsq[i] = sumXsq[i - 1] + deviation * deviation;
    }
    // Column 0 (k = 0): the range always starts at index 0
    matrix[0][i] = ssq(0, i, sumX, sumXsq);
    backtrackMatrix[0][i] = 0;
    await drawMatrixCell($matrix, matrix, 0, i, { flash: true, color: true });
    await sleep();
  }
  // Remaining columns
  for (let k = 1; k < matrix.length; ++k) {
    // In the last column only the final row is needed:
    // no need to compute matrix[K-1][0] ... matrix[K-1][N-2]
    const isLastColumn = !(k < matrix.length - 1);
    const imin = isLastColumn ? nValues - 1 : k;
    await fillMatrixColumn(imin, nValues - 1, k, matrix, backtrackMatrix, sumX, sumXsq);
  }
}
/**
 * Returns the largest |s(0, i)| over all rows i, i.e. the largest value that
 * column 0 of the sum-of-squares matrix will hold. ckmeans uses it as the
 * upper end of the colour scale's domain.
 * @param {Array<number>} data input values (ckmeans passes the sorted array)
 * @param {number} nValues number of values to consider
 * @returns {number|undefined} the maximum, or undefined when nValues is 0
 */
function getMaxValue(data, nValues) {
  const sumX = new Array(nValues);
  const sumXsq = new Array(nValues);
  // Shift every value by the middle element to improve numerical stability
  const shift = data[Math.floor(nValues / 2)];
  const costs = [];
  for (let i = 0; i < nValues; ++i) {
    const deviation = data[i] - shift;
    if (i === 0) {
      sumX[0] = deviation;
      sumXsq[0] = deviation * deviation;
    } else {
      sumX[i] = sumX[i - 1] + data[i] - shift;
      sumXsq[i] = sumXsq[i - 1] + deviation * deviation;
    }
    // Same quantity that column 0 (k = 0) of the matrix holds for row i
    const cost = ssq(0, i, sumX, sumXsq);
    costs.push(Math.abs(cost));
  }
  return _.max(costs);
}
/**
 * Draws one rectangle and one text label per matrix cell inside $el.
 *
 * Cells are keyed by "<col>-<row>" and only cells that are not yet present
 * are created (enter selection), so drawing the same matrix again adds
 * nothing. Each element carries its key in a data-key attribute so that
 * individual cells can be looked up later.
 *
 * @param {Object} $el d3 selection of the container group
 * @param {Array<Array<number>>} matrix values addressed as matrix[col][row]
 * @param {number} [cellWidth=CELL_WIDTH] cell width in px
 * @param {number} [cellHeight=CELL_HEIGHT] cell height in px
 */
function drawMatrix($el, matrix, cellWidth = CELL_WIDTH, cellHeight = CELL_HEIGHT) {
  // Flatten the columns into one list of cell descriptors for the data join.
  const cells = _.flatten(
    _.map(matrix, (col, i) => (
      col.map((number, j) => ({number, col: i, row: j, key: `${i}-${j}`}))
    ))
  );
  $el.selectAll('.matrix-rect')
    .data(cells, cell => cell.key)
    .enter()
    .append('rect')
    .classed('matrix-rect', true)
    .attr('data-key', cell => cell.key)
    .attr('x', cell => cellWidth * cell.col)
    .attr('y', cell => cellHeight * cell.row)
    .attr('fill', '#eee')
    .attr('stroke', CELL_STROKE_COLOR)
    // The SVG attribute is spelled "stroke-width"; "strokeWidth" is ignored.
    .attr('stroke-width', 1)
    .attr('width', cellWidth)
    .attr('height', cellHeight);
  $el.selectAll('.matrix-label')
    .data(cells, cell => cell.key)
    .enter()
    .append('text')
    .style('font-size', '10px')
    .classed('matrix-label', true)
    .attr('data-key', cell => cell.key)
    .attr('fill', 'rgba(0, 0, 0, 0.4)')
    .attr('x', cell => cellWidth * (cell.col + 0.5))
    .attr('y', cell => cellHeight * (cell.row + 0.6))
    .attr('alignment-baseline', 'middle')
    .attr('text-anchor', 'middle')
    .text(cell => formatValue(cell.number));
}
/**
 * Selects the rectangle of the cell at (col, row) inside $el, matching on the
 * "<col>-<row>" value of its data-key attribute.
 */
function getRectElement($el, col, row) {
  const key = `${col}-${row}`;
  const selector = `.matrix-rect[data-key="${key}"]`;
  return $el.select(selector);
}
/**
 * Selects the text label of the cell at (col, row) inside $el, matching on the
 * "<col>-<row>" value of its data-key attribute.
 */
function getTextElement($el, col, row) {
  const key = `${col}-${row}`;
  const selector = `.matrix-label[data-key="${key}"]`;
  return $el.select(selector);
}
/**
 * Refreshes a single cell: re-appends it to the end of its container, applies
 * the requested stroke/fill effects and rewrites its label with the current
 * matrix value.
 *
 * @param {Object} $el d3 selection of the matrix container
 * @param {Array<Array<number>>} matrix values addressed as matrix[col][row]
 * @param {number} col column of the cell
 * @param {number} row row of the cell
 * @param {Object} [action] effect flags:
 *   highlight - set the orange stroke and keep it;
 *   flash     - set the orange stroke, then transition back to the resting stroke;
 *   color     - fill the cell using the global colour scale
 */
function drawMatrixCell($el, matrix, col, row, action = {}) {
  const $rect = getRectElement($el, col, row);
  const $label = getTextElement($el, col, row);
  $rect.moveToFront();
  $label.moveToFront();

  const markStroke = () => $rect.interrupt().attr('stroke', '#ff7b2e');
  if (action.flash) {
    markStroke();
    $rect.transition().attr('stroke', CELL_STROKE_COLOR);
  } else if (action.highlight) {
    markStroke();
  }

  if (action.color) {
    $rect.attr('fill', () => scale(matrix[col][row]));
  }

  const label = formatValue(matrix[col][row]);
  $label.text(label);
}
/**
 * Animates every cell rectangle and label whose row is at or below
 * separateAtIndex 20px further down, opening a visual gap above that row.
 * The `cluster` argument is accepted but not used.
 *
 * @param {Object} $el d3 selection of the matrix container
 * @param {number} cluster index of the cluster being split off (unused)
 * @param {number} separateAtIndex first row to shift
 */
function separateCluster($el, cluster, separateAtIndex) {
  const isAtOrBelowSplit = (cell) => cell.row >= separateAtIndex;
  // Needs a regular function: `this` is the DOM node being transitioned.
  const shiftedY = function () {
    const currentY = Number(d3.select(this).attr('y'));
    return currentY + 20;
  };
  const $moving = $el
    .selectAll('.matrix-rect,.matrix-label')
    .filter(isAtOrBelowSplit);
  $moving.transition().attr('y', shiftedY);
}
/**
 * Shows the arrow marker with the given id next to a matrix, creating it on
 * first use and otherwise transitioning the existing one to its new place.
 *
 * A vertical arrow sits above column `index` (rotated by 90 degrees); a
 * horizontal arrow sits to the left of row `index - 1`.
 *
 * @param {Object} $el d3 selection of the matrix container
 * @param {string} id element id of the arrow group
 * @param {boolean} isVertical true for a column arrow, false for a row arrow
 * @param {number} index column index, or row number, to point at
 */
function drawArrow($el, id, isVertical, index) {
  const arrowSize = 12;
  let transform;
  if (isVertical) {
    const left = index * CELL_WIDTH + (CELL_WIDTH - arrowSize) / 2;
    const top = -arrowSize;
    transform = `translate(${left}, ${top}) rotate(90 6 6)`;
  } else {
    const left = -arrowSize;
    const top = (index - 1) * CELL_HEIGHT + (CELL_HEIGHT - arrowSize) / 2;
    transform = `translate(${left}, ${top})`;
  }

  let $arrow = $el.select(`#${id}`);
  const isMissing = $arrow.empty();
  if (isMissing) {
    // First use: create the group at its target position and add the icon.
    $arrow = $el.append('g');
    $arrow
      .classed('arrow', true)
      .style('opacity', 1)
      .attr('id', id)
      .attr('transform', transform);
    $arrow.html(`<svg height="12" viewBox="0 0 48 48" width="12" xmlns="http://www.w3.org/2000/svg"><path d="M0 0h48v48h-48z" fill="none"/><path d="M24 8l-2.83 2.83 11.17 11.17h-24.34v4h24.34l-11.17 11.17 2.83 2.83 16-16z"/></svg>`);
  }

  const move = $arrow.transition();
  move
    .attr('transform', transform)
    .style('opacity', 1);
}
/**
 * Ckmeans clustering is an improvement on heuristic-based clustering
 * approaches like Jenks. The algorithm was developed by
 * [Haizhou Wang and Mingzhou Song](http://journal.r-project.org/archive/2011-2/RJournal_2011-2_Wang+Song.pdf)
 * as a [dynamic programming](https://en.wikipedia.org/wiki/Dynamic_programming) approach
 * to the problem of clustering numeric data into groups with the least
 * within-group sum-of-squared-deviations.
 *
 * Minimizing the difference within groups - what Wang & Song refer to as
 * `withinss`, or within sum-of-squares, means that groups are optimally
 * homogeneous within and the data is split into representative groups.
 * This is very useful for visualization, where you may want to represent
 * a continuous variable in discrete color or style groups. This function
 * can provide groups that emphasize differences between data.
 *
 * Being a dynamic approach, this algorithm is based on two matrices that
 * store incrementally-computed values for squared deviations and backtracking
 * indexes.
 *
 * Unlike the [original implementation](https://cran.r-project.org/web/packages/Ckmeans.1d.dp/index.html),
 * this implementation does not include any code to automatically determine
 * the optimal number of clusters: this information needs to be explicitly
 * provided.
 *
 * This version also animates every step: it draws the input vector and both
 * matrices, fills them cell by cell, and then walks the backtrack matrix
 * while splitting the input column into its clusters. When all input values
 * are identical it returns immediately without drawing anything.
 *
 * ### References
 * _Ckmeans.1d.dp: Optimal k-means Clustering in One Dimension by Dynamic
 * Programming_ Haizhou Wang and Mingzhou Song ISSN 2073-4859
 *
 * from The R Journal Vol. 3/2, December 2011
 * @param {Array<number>} data input data, as an array of number values
 * @param {number} nClusters number of desired classes. This cannot be
 * greater than the number of values in the data array.
 * @returns {Promise<Array<Array<number>>>} resolves to the clustered input
 * once the animation has finished
 * @throws {Error} if nClusters is greater than the number of data values
 * @example
 * await ckmeans([-1, 2, -1, 2, 4, 5, 6, -1, 2, -1], 3);
 * // The input, clustered into groups of similar numbers.
 * //= [[-1, -1, -1, -1], [2, 2, 2], [4, 5, 6]]);
 */
async function ckmeans (data, nClusters) {
  if (nClusters > data.length) {
    throw new Error('Cannot generate more classes than there are data values');
  }
  const nValues = data.length;
  const sorted = numericSort(data);
  // we'll use this as the maximum number of clusters
  const uniqueCount = uniqueCountSorted(sorted);
  // if all of the input values are identical, there's one cluster
  // with all of the input in it.
  if (uniqueCount === 1) {
    return [sorted];
  }
  nClusters = Math.min(uniqueCount, nClusters);
  // named 'S' originally
  const matrix = makeMatrix(nClusters, nValues);
  // named 'J' originally
  const backtrackMatrix = makeMatrix(nClusters, nValues);
  scale = d3.scalePow()
    .exponent(0.3)
    .domain([0, getMaxValue(sorted, nValues)])
    .range(['#ffffb2', '#bd0026']);
  // Draw the sorted values: every row index used below (matrix rows,
  // backtrack indexes, cluster separation) refers to positions in `sorted`.
  drawMatrix($vector, [sorted]);
  drawMatrix($matrix, matrix);
  drawMatrix($backtrackMatrix, backtrackMatrix);
  // This is a dynamic programming way to solve the problem of minimizing
  // within-cluster sum of squares. It's similar to linear regression
  // in this way, and this calculation incrementally computes the
  // sum of squares that are later read.
  await fillMatrices(sorted, matrix, backtrackMatrix);
  await sleep(250);
  // The real work of Ckmeans clustering happens in the matrix generation:
  // the generated matrices encode all possible clustering combinations, and
  // once they're generated we can solve for the best clustering groups
  // very quickly.
  const clusters = [];
  let clusterRight = backtrackMatrix[0].length - 1;
  await drawArrow($backtrackMatrix, 'col-arrow', true, nClusters - 1);
  await drawArrow($backtrackMatrix, 'row-arrow', false, nValues);
  await sleep(1000);
  // Backtrack the clusters from the dynamic programming matrix. This
  // starts at the bottom-right corner of the matrix (if the top-left is 0, 0),
  // and moves the cluster target with the loop.
  for (let cluster = backtrackMatrix.length - 1; cluster >= 0; cluster--) {
    const clusterLeft = backtrackMatrix[cluster][clusterRight];
    // fill the cluster from the sorted input by taking a slice of the
    // array. the backtrack matrix makes this easy - it stores the
    // indexes where the cluster should start and end.
    clusters[cluster] = sorted.slice(clusterLeft, clusterRight + 1);
    if (cluster > 0) {
      drawMatrixCell($backtrackMatrix, backtrackMatrix, cluster, clusterRight, { highlight: true });
      await sleep(1000);
      await drawArrow($backtrackMatrix, 'row-arrow', false, clusterLeft);
      await sleep(1000);
      separateCluster($vector, cluster, clusterLeft);
      separateCluster($backtrackMatrix, cluster, clusterLeft);
      if (cluster > 1) {
        await sleep(1000);
        await drawArrow($backtrackMatrix, 'col-arrow', true, cluster - 1);
      }
      await sleep(1000);
      clusterRight = clusterLeft - 1;
    }
  }
  // Reset all arrows and highlights
  $backtrackMatrix.selectAll('.arrow').transition().style('opacity', 0);
  $backtrackMatrix.selectAll('.matrix-rect').transition().attr('stroke', CELL_STROKE_COLOR);
  await sleep(600);
  // Hide other matrices, just show vector
  $backtrackMatrix.transition().style('opacity', 0);
  $matrix.transition().style('opacity', 0);
  await sleep(600);
  $vector
    .transition()
    .attr('transform', 'translate(250, 50)');
  return clusters;
}
// Start the animation. ckmeans() is async, so report a failure explicitly
// rather than leaving the returned promise without a rejection handler.
ckmeans([1, 2, 4, 5, 43, 52, 62, 77, 100, 111, 123, 234, 724, 788, 796, 1004, 1026, 1082, 1244], nClusters)
  .catch((error) => {
    console.error(error);
  });
</script>
</body>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment