Animate Denoising AutoEncoder
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*-
"""
Animate the filters learned by denoising autoencoders during training.

This code uses the "dA" class defined in:
http://www.deeplearning.net/tutorial/dA.html
"""
import os | |
import numpy | |
import theano | |
import theano.tensor as T | |
from theano.tensor.shared_randomstreams import RandomStreams | |
from logistic_sgd import load_data | |
from utils import tile_raster_images | |
from dA import dA | |
def test_dA(learning_rate=0.1, training_epochs=20,
            dataset='mnist.pkl.gz',
            batch_size=20, output_folder='dA_plots', save_figures=False):
    """Train an uncorrupted and a 30%-corrupted dA side by side and plot them.

    Both denoising autoencoders are trained on the same minibatches in
    lockstep.  After every epoch a two-panel figure is drawn: each panel
    shows the model's transposed weight matrix ``W`` rendered as a 10x10
    grid of 28x28 tiles, titled with that model's mean training cost for
    the epoch.

    :param learning_rate: SGD learning rate passed to ``get_cost_updates``.
    :param training_epochs: number of passes over the training set.
    :param dataset: dataset path handed to ``logistic_sgd.load_data``.
    :param batch_size: number of examples per minibatch; a trailing partial
        minibatch is skipped.
    :param output_folder: directory created (if missing) to hold saved
        frames.
    :param save_figures: if True, write each epoch's figure to
        ``<output_folder>/da_<epoch:02d>.png`` instead of displaying it with
        ``plt.show()``.  Defaults to False, which only displays figures.
    """
    datasets = load_data(dataset)
    # Only the inputs are needed: autoencoder training is unsupervised.
    train_set_x = datasets[0][0]

    # Number of full training minibatches.  Floor division keeps this an int
    # on Python 3 as well, where ``/`` would produce a float.
    n_train_batches = train_set_x.get_value(borrow=True).shape[0] // batch_size

    # allocate symbolic variables for the data
    index = T.lscalar()  # index to a [mini]batch
    x = T.matrix('x')  # the data is presented as rasterized images

    # Create the output directory up front.  The process working directory is
    # deliberately left unchanged; frames are written via explicit paths.
    if not os.path.isdir(output_folder):
        os.makedirs(output_folder)

    #########################################################
    # BUILDING THE MODELS: NO CORRUPTION AND 30% CORRUPTION #
    #########################################################
    rng = numpy.random.RandomState(123)
    theano_rng = RandomStreams(rng.randint(2 ** 30))

    def get_dA():
        # Both models share ``rng``/``theano_rng`` and the symbolic input ``x``.
        return dA(numpy_rng=rng, theano_rng=theano_rng, input=x,
                  n_visible=28 * 28, n_hidden=500)

    da0 = get_dA()
    da30 = get_dA()

    def build_train_func(da, corruption_level):
        """Compile a function that runs one update step on minibatch ``index``
        and returns its cost."""
        cost, updates = da.get_cost_updates(
            corruption_level=corruption_level,
            learning_rate=learning_rate
        )
        return theano.function(
            [index], cost, updates=updates,
            givens={x: train_set_x[index * batch_size: (index + 1) * batch_size]})

    train_da0 = build_train_func(da0, 0.)
    train_da30 = build_train_func(da30, 0.3)

    def get_image(da):
        """Render the transposed weights of ``da`` as a 10x10 grid of 28x28 tiles."""
        return tile_raster_images(X=da.W.get_value(borrow=True).T,
                                  img_shape=(28, 28), tile_shape=(10, 10),
                                  tile_spacing=(1, 1))

    # Imported here rather than at module level, so importing this module
    # does not require matplotlib.
    import matplotlib.pyplot as plt
    gray = plt.get_cmap('gray')

    for epoch in range(training_epochs):
        c0 = []
        c30 = []
        for batch_index in range(n_train_batches):
            c0.append(train_da0(batch_index))
            c30.append(train_da30(batch_index))

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(7, 4))
        fig.suptitle('Training epoch {0}'.format(epoch))
        ax1.imshow(get_image(da0), cmap=gray)
        ax1.set_title('Corruption rate: 0%, Cost: {0:.2f}'.format(numpy.mean(c0)), size=9)
        ax2.imshow(get_image(da30), cmap=gray)
        ax2.set_title('Corruption rate: 30%, Cost: {0:.2f}'.format(numpy.mean(c30)), size=9)
        for ax in (ax1, ax2):
            ax.xaxis.set_visible(False)
            ax.yaxis.set_visible(False)

        if save_figures:
            fig.savefig(os.path.join(output_folder, 'da_{0:02d}.png'.format(epoch)))
        else:
            plt.show()
        # Release the figure explicitly; otherwise non-interactive backends
        # keep every epoch's figure alive until the process exits.
        plt.close(fig)
if __name__ == "__main__":
    # When run as a script, train and animate with the default settings.
    test_dA()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment