# Shared as GitHub gist josharian/1525765 by @josharian (created December 28, 2011).
from time import time
import pylab as pl
import scipy as sp
import numpy as np
from sklearn.decomposition import MiniBatchDictionaryLearning
from sklearn.feature_extraction.image import PatchExtractor
from sklearn.feature_extraction.image import reconstruct_from_patches_2d
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import Scaler
###############################################################################
# Load the Lena test image and shrink it to half resolution.
# Dividing by 256.0 normalises the pixel values (presumably 8-bit intensities,
# which would land in [0, 1) -- confirm against the loader).
lena = sp.lena() / 256.0
# Average every non-overlapping 2x2 block of pixels; halving both dimensions
# makes the dictionary learning below much faster.
lena = (lena[::2, ::2] + lena[1::2, ::2]
        + lena[::2, 1::2] + lena[1::2, 1::2]) / 4.0
height, width = lena.shape
# Distort the right half of the image with additive Gaussian noise
# (standard-normal samples scaled by 0.075); the left half stays clean.
print('Distorting image...')
distorted = lena.copy()
# Split on the column axis (width) and size the noise to exactly match the
# sliced region, shape (height, width - width // 2), so non-square images work
# as well.  Floor division keeps the slice index an integer on Python 2 and 3.
distorted[:, width // 2:] += 0.075 * np.random.randn(height, width - width // 2)
# Extract all clean patches from the left half of the image (noise was only
# added to the right half), then learn a sparse dictionary from them.
#
# NOTE(review): this targets the 2011-era scikit-learn API; later releases
# renamed Scaler -> StandardScaler, n_atoms -> n_components and
# n_iter -> max_iter.  Update these names if running on a newer version.
print('Extracting clean patches...')
t0 = time()
patch_size = (7, 7)
data = distorted[:, :width // 2]
print("Data shape %s" % (data.shape,))
# PatchExtractor.transform operates on a batch of images, shape
# (n_images, h, w), so give the single half-image a leading image axis.
data = data.reshape((1,) + data.shape)
print("Data reshaped %s" % (data.shape,))
patches = PatchExtractor(patch_size=patch_size).transform(data)
# The extractor returns (n_patches, 7, 7), but Scaler and dictionary learning
# require a 2-D (n_samples, n_features) matrix.  Flatten each patch into one
# row here; this shape mismatch is why the extractor cannot be chained
# directly into the Pipeline below.
patches = patches.reshape(patches.shape[0], -1)
print("Patches shape %s" % (patches.shape,))
print('done in %.2fs.' % (time() - t0))

###############################################################################
# Learn the dictionary from clean patches
print('Learning the dictionary... ')
t0 = time()
pipeline = Pipeline([
    ("scale", Scaler()),
    ("sparse", MiniBatchDictionaryLearning(n_atoms=100, alpha=1e-2, n_iter=500)),
])
# Pipeline.fit() returns the Pipeline itself, which has no components_; the
# learned atoms live on the final (dictionary-learning) step.
V = pipeline.fit(patches).named_steps["sparse"].components_
print("V shape %s" % (V.shape,))
dt = time() - t0
print('done in %.2fs.' % dt)
# Discussion of this gist happens on its GitHub page (sign in there to comment).