Skip to content

Instantly share code, notes, and snippets.

@eickenberg
Created April 4, 2015 16:45
Show Gist options
  • Save eickenberg/561b5c3b8eec315dde07 to your computer and use it in GitHub Desktop.
Save eickenberg/561b5c3b8eec315dde07 to your computer and use it in GitHub Desktop.
Draft of a convolutional ZCA
# Disclaimer: This doesn't seem to work 100% yet, but almost ;)
# Convolutional ZCA
# When images have too many pixels for the principal components of an image
# batch to be computed directly, one can assume that the eigen-structure is
# translation invariant and perform the ZCA in a convolutional manner.
import theano
import theano.tensor as T
import numpy as np
from sklearn.feature_extraction.image import extract_patches
from sklearn.decomposition import PCA
def convolutional_zca(X, patch_size=(8, 8), step_size=(2, 2), epsilon=0.):
    """Perform convolutional ZCA whitening on a batch of images.

    A PCA is fitted on patches extracted from ``X``. Its components are then
    applied to the full images as convolution filters, which assumes that the
    whitening transform is translation invariant.

    Parameters
    ----------
    X : ndarray, shape (n_imgs, height, width, n_channels)
        Batch of images in channels-last layout. Its dtype is also used as the
        dtype of the Theano graph, so a float dtype (e.g. ``float32``) is
        expected.
    patch_size : tuple of int, length 2 or 3
        Patch ``(height, width)`` or ``(height, width, n_channels)``. When
        only two values are given, the patch spans all channels.
    step_size : tuple of int, length 2 or 3
        Stride used when extracting patches for the PCA fit. When only two
        values are given, the channel stride is 1.
    epsilon : float, default 0.
        Added to each explained variance before the inverse square root.
        The default of 0 gives plain ZCA. A small positive value regularizes
        components with (near-)zero variance. Those would otherwise be scaled
        by ``inf`` or by very large factors.

    Returns
    -------
    whitened : ndarray, shape (n_imgs, n_channels, height, width)
        Whitened images with dtype ``X.dtype``. NOTE: the result is
        channels-first, unlike the channels-last input.

    Raises
    ------
    ValueError
        If ``X`` is not 4-dimensional.
    """
    if X.ndim != 4:
        raise ValueError(
            "X must be a 4D array (n_imgs, height, width, n_channels), "
            "got shape %r" % (X.shape,))
    n_imgs, h, w, n_channels = X.shape
    # Normalize to tuples so list arguments work with the concatenations below.
    patch_size = tuple(patch_size)
    step_size = tuple(step_size)
    if len(patch_size) == 2:
        patch_size = patch_size + (n_channels,)
    if len(step_size) == 2:
        step_size = step_size + (1,)

    # The leading 1 in the patch shape keeps every patch inside one image.
    # NOTE(review): `extract_patches` is deprecated/removed in recent
    # scikit-learn releases. Confirm the installed version still provides it.
    patches = extract_patches(X,
                              (1,) + patch_size,
                              (1,) + step_size).reshape((-1,) + patch_size)
    pca = PCA()
    pca.fit(patches.reshape(patches.shape[0], -1))
    # NOTE(review): PCA centers the patches during fit (pca.mean_), but no
    # mean is subtracted from the images before filtering below. This may
    # explain why the result is only "almost" right. Confirm whether
    # centering is needed.

    # components_ rows are flattened (ph, pw, c) patches. Reorder them into
    # Theano's (n_filters, n_input_channels, rows, cols) filter layout.
    components = theano.shared(pca.components_.reshape(
        (-1,) + patch_size).transpose(0, 3, 1, 2).astype(X.dtype))
    # Per-component 1 / sqrt(variance), broadcast over batch and space axes.
    whitening_factors = T.addbroadcast(
        theano.shared(
            1. / np.sqrt(pca.explained_variance_ + epsilon)
            .astype(X.dtype).reshape((1, -1, 1, 1))),
        0, 2, 3)
    # Transposed filter bank: swap the filter/channel axes and flip it
    # spatially. It maps the component responses back to image channels.
    componentsT = components.dimshuffle((1, 0, 2, 3))[:, :, ::-1, ::-1]

    input_images = T.tensor4(dtype=X.dtype)
    # A 'full' projection followed by a 'valid' back-projection gives an
    # output with the input's spatial size (height, width).
    projected = T.nnet.conv2d(input_images.dimshuffle((0, 3, 1, 2)),
                              components, border_mode='full')
    conv_whitening = T.nnet.conv2d(projected * whitening_factors, componentsT)
    f_whitening = theano.function([input_images], conv_whitening)
    return f_whitening(X)
if __name__ == "__main__":
    import os
    import pickle
    import sys

    # The dataset location can be overridden with the CIFAR_DIR env variable.
    CIFAR_DIR = os.environ.get(
        'CIFAR_DIR', '/home/me/data/datasets/cifar-10-batches-py')
    CIFAR_FILE = 'data_batch_2'
    n_images = 1000

    # Pickles must be read in binary mode. The `with` block closes the file.
    # Only unpickle dataset files obtained from a trusted source.
    with open(os.path.join(CIFAR_DIR, CIFAR_FILE), 'rb') as f:
        if sys.version_info >= (3,):
            # Batches pickled under Python 2 need latin1 decoding on
            # Python 3, so the string keys (e.g. 'data') come back as str.
            cifar = pickle.load(f, encoding='latin1')
        else:
            cifar = pickle.load(f)

    # Rows are flattened (3, 32, 32) channel-first images. Convert them to
    # channels-last (n, 32, 32, 3), as convolutional_zca expects.
    images = cifar['data'][:n_images].reshape(
        -1, 3, 32, 32).transpose(0, 2, 3, 1)
    whitened = convolutional_zca(images.astype(np.float32))
@TronicLT
Copy link

TronicLT commented Feb 4, 2016

Check David Eigen's thesis for an algorithmic description of Convolutional ZCA Whitening (section 8.4).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment