Skip to content

Instantly share code, notes, and snippets.

@kennethjmyers
Forked from michiexile/gap.py
Last active February 26, 2016 21:57
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kennethjmyers/d46f88d9ebc13eddf65a to your computer and use it in GitHub Desktop.
Save kennethjmyers/d46f88d9ebc13eddf65a to your computer and use it in GitHub Desktop.
A Python implementation of the gap statistic of Tibshirani, Walther, and Hastie for estimating the number of clusters in a dataset under k-means clustering.
# gap.py
# (c) 2013 Mikael Vejdemo-Johansson
# BSD License
#
# SciPy function to compute the gap statistic for evaluating k-means clustering.
# Gap statistic defined in
# Tibshirani, Walther, Hastie:
# Estimating the number of clusters in a data set via the gap statistic
# J. R. Statist. Soc. B (2001) 63, Part 2, pp 411-423
import numpy as np
import scipy
import scipy.cluster.vq
import scipy.spatial.distance
# Module-level shorthand for SciPy's Euclidean distance between two 1-D vectors.
from scipy.spatial.distance import euclidean as dst
def _uniform_references(data, nrefs):
    """Draw ``nrefs`` uniform samples shaped like ``data`` inside its bounding box.

    Returns an array of shape (n, m, nrefs); slice ``[:, :, j]`` is reference j.
    Uses the global ``numpy.random`` state.
    """
    tops = data.max(axis=0)
    bots = data.min(axis=0)
    unit = np.random.random_sample(size=(data.shape[0], data.shape[1], nrefs))
    # (m, 1) broadcasts against the trailing (m, nrefs) axes, scaling each feature column.
    return bots[:, None] + unit * (tops - bots)[:, None]


def _dispersion(points, k):
    """Return W_k: the pooled within-cluster sum of squared Euclidean distances.

    ``points`` is clustered with ``scipy.cluster.vq.kmeans2`` into ``k`` clusters.
    The paper's W_k = sum_r D_r / (2 n_r) equals this sum of squared distances
    from each point to its assigned centroid.
    """
    centroids, labels = scipy.cluster.vq.kmeans2(points, k)
    return np.sum((points - centroids[labels]) ** 2)


def gap(data, refs=None, nrefs=20, ks=range(1, 11)):
    """Compute the gap statistic for k-means clustering of an (n, m) dataset.

    Implements Gap(k) = E*[log W_k] - log W_k from Tibshirani, Walther and
    Hastie (2001). W_k is the within-cluster sum of squared Euclidean distances
    of the points to their k-means centroids. E* is the average of log W_k
    over the reference datasets.

    Parameters
    ----------
    data : array_like, shape (n, m)
        Observations in rows, features in columns.
    refs : array_like, shape (n, m, B), optional
        Precomputed reference datasets stacked along the last axis. When
        omitted, ``nrefs`` references are drawn from a uniform distribution
        over the bounding box of ``data`` (global ``numpy.random`` state).
    nrefs : int
        Number of reference datasets to generate; ignored when ``refs`` is given.
    ks : sequence of int
        Cluster counts for which to compute the statistic.

    Returns
    -------
    numpy.ndarray, shape (len(ks),)
        The gap value for each entry of ``ks``, in order.

    Notes
    -----
    k-means uses the ``kmeans2`` defaults (random initialisation), so results
    vary between calls unless the random state is fixed. A clustering with zero
    dispersion (e.g. k >= number of distinct points) makes log W_k equal -inf.
    """
    data = np.ascontiguousarray(data, dtype=float)
    # `is None`: `refs == None` on an ndarray is elementwise and cannot be used in `if`.
    if refs is None:
        refs = _uniform_references(data, nrefs)
    else:
        refs = np.asarray(refs, dtype=float)
    # refs[:, :, j] is a strided view; copy each into a plain contiguous 2-D array.
    ref_sets = [np.ascontiguousarray(refs[:, :, j]) for j in range(refs.shape[2])]

    gaps = np.zeros(len(ks))
    for i, k in enumerate(ks):
        log_wk = np.log(_dispersion(data, k))
        # Expectation of log W_k under the null: mean of the logs, not log of a sum of logs.
        expected_log_wk = np.mean([np.log(_dispersion(ref, k)) for ref in ref_sets])
        gaps[i] = expected_log_wk - log_wk
    return gaps
@kennethjmyers
Copy link
Author

Changed line 48 to reflect the equation in the gap-statistic paper (Tibshirani, Walther & Hastie, 2001).

@kennethjmyers
Copy link
Author

Further corrected lines 47 and 48 according to the same paper.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment