Skip to content

Instantly share code, notes, and snippets.

@JBed
Created July 11, 2015 22:45
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save JBed/4ca8012dad91bf055e55 to your computer and use it in GitHub Desktop.
Save JBed/4ca8012dad91bf055e55 to your computer and use it in GitHub Desktop.
k-means unsupervised pre-training in Python
# http://jmlr.org/papers/volume11/erhan10a/erhan10a.pdf
import cPickle as pickle
import numpy as np
from matplotlib import pyplot as plt
from os.path import join
from sklearn.cluster import KMeans
# download data here: http://www.cs.toronto.edu/~kriz/cifar.html
# Load one CIFAR-10 training batch (python pickle version, see link above).
# NOTE(review): pickle.load can execute arbitrary code -- only unpickle the
# archive downloaded from the official CIFAR-10 page, never untrusted files.
with open(join('data', 'cifar-10-batches-py', 'data_batch_1'), 'rb') as f:
    data = pickle.load(f)
# data['data'] holds one flattened image per row; reshape it to
# (N, 3, 32, 32) and convert to float32 scaled by 1/255 (presumably the raw
# pixels are 0-255 integers, so this yields values in [0, 1] -- confirm).
images = data['data'].reshape((-1, 3, 32, 32)).astype('float32') / 255
# Move the channel axis last: (N, 3, 32, 32) -> (N, 32, 32, 3). The patch
# slicing and the 5x5x3 filter reshape below rely on this channels-last layout.
images = np.rollaxis(images, 1, 4)
# collect patches: cut non-overlapping 5x5 patches out of every image on a
# stride-5 grid. range(0, 32 - 5, 5) yields offsets 0, 5, ..., 25, i.e. a 6x6
# grid per image; the last 2 rows/columns of each 32x32 image are unused.
offsets = range(0, 32 - 5, 5)
# Gather all slices first and concatenate once. Concatenating inside the loop
# re-copies the ever-growing array on every iteration (quadratic copying).
# Order is preserved: x is the outer loop, y the inner one, and each step
# contributes one patch from every image.
patch_list = [images[:, x:x + 5, y:y + 5, :]
              for x in offsets for y in offsets]
# Keep float64 for the covariance / eigendecomposition steps that follow.
patches = np.concatenate(patch_list, axis=0).astype(np.float64)
# Flatten each 5x5x3 patch into a 75-element row vector.
patches = patches.reshape((patches.shape[0], -1))
# normalize: standardize every feature (column) of the flattened patches.
# The divisor is the column's standard deviation plus 1/20 of its value range
# (ptp = max - min) -- presumably a regularizer that keeps the divisor from
# becoming tiny for low-variance features.
mu = np.mean(patches, axis=0)
spread = np.ptp(patches, axis=0) / 20.0
sigma = np.std(patches, axis=0) + spread
# (N, 75) op (75,) broadcasts across rows, i.e. per-column centering/scaling.
patches = (patches - mu) / sigma
# zca whiten: W = U diag((lambda + eps)^-1/2) U^T, where lambda / U are the
# eigenvalues / eigenvectors of the feature covariance. eps = 0.01 keeps
# near-zero-variance directions from being amplified without bound.
# The covariance is symmetric, so np.linalg.eigh is the right routine: it
# always returns real eigenvalues and orthonormal eigenvectors. np.linalg.eig
# may return complex values from rounding error and does not guarantee
# orthonormal eigenvectors for repeated eigenvalues. W is basis-independent,
# so eigh's ascending eigenvalue order does not change the result.
eig_values, eig_vec = np.linalg.eigh(np.cov(patches, rowvar=False))
zca = eig_vec.dot(np.diag((eig_values + 0.01) ** -0.5)).dot(eig_vec.T)
# W is symmetric, so right-multiplying the row vectors applies the whitening.
patches = np.dot(patches, zca)
# k-means: cluster the whitened patches; the cluster centres are used as the
# learned filters.
NUM_FILTERS = 64
# n_jobs is intentionally not passed: scikit-learn deprecated it for KMeans
# in 0.23 and removed it in 1.0, where passing it raises TypeError. Older
# versions defaulted to a single job anyway, so omitting it is equivalent.
km = KMeans(n_clusters=NUM_FILTERS, random_state=0, n_init=1, verbose=True)
km.fit(patches)
# Each centre is a 75-element vector in whitened patch space; reshape back to
# 5x5x3 (channels-last, matching the patch layout) for display.
filters = km.cluster_centers_.reshape((NUM_FILTERS, 5, 5, 3))
# display: draw every learned filter as a small RGB image in a 4-row grid.
fig = plt.figure()
num_col = int(np.ceil(float(NUM_FILTERS) / 4))
for i in range(NUM_FILTERS):
    # matplotlib subplot indices are 1-based (index 0 is invalid).
    ax = fig.add_subplot(4, num_col, i + 1)
    # Copy: filters[i, ...] is a view, and the in-place rescaling below would
    # otherwise overwrite the learned filters themselves.
    filter_ = filters[i, ...].copy()
    # Rescale to [0, 1] for imshow. A constant filter has zero range, so skip
    # the division instead of producing 0/0 = NaN.
    filter_ -= filter_.min()
    max_val = filter_.max()
    if max_val > 0:
        filter_ /= max_val
    ax.imshow(filter_, interpolation='none')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
# Without this, a non-interactive run exits without ever showing the figure.
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment