@gyk
Created July 17, 2014 06:23
Stacked Denoising Auto-Encoder of DeepLearnToolbox (https://github.com/rasmusbergpalm/DeepLearnToolbox) - a simple example
% Playing with the Stacked Denoising Auto-Encoder in the DeepLearnToolbox
% (https://github.com/rasmusbergpalm/DeepLearnToolbox),
% and trying to use it to extract features.
%% Loads data
load mnist_uint8;
% I don't have much memory
clear test_x test_y
% The MNIST data that comes with DeepLearnToolbox contains 60000 training
% samples, which causes an out-of-memory error on my old computer, so only
% the first 10000 are used.
% Fortunately, the 10 digits are roughly evenly distributed in the data,
% so this subset still covers every class.
train_x = train_x(1:10000, :);
train_y = train_y(1:10000, :);
train_x = double(train_x) / 255;
train_y = double(train_y);
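% Sanity check (a small addition, not part of the original experiment):
% count how many samples of each digit ended up in the 10000-sample subset,
% to confirm that every class is represented. train_y is one-hot, and
% column k corresponds to digit k - 1.
classCounts = sum(train_y, 1);
for k = 1:10
    fprintf('digit %d: %d samples\n', k - 1, classCounts(k));
end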
%% SAE setup & training
rand('twister', 5489); % reset the RNG to MATLAB's default start-up state (legacy syntax)
sae = saesetup([784 200 100]);
nAutoEncoders = numel(sae.ae);
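for i = 1:nAutoEncoders
    % Why does the toolbox author mix camelCase and under_score? :S
    sae.ae{i}.activation_function = 'sigm'; % logistic sigmoid units
    sae.ae{i}.learningRate = 1;
    % Denoising: during training, each input value is set to zero with
    % probability 0.5 before it is fed to the autoencoder.
    sae.ae{i}.inputZeroMaskedFraction = 0.5;
end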
opts.numepochs = 1;
opts.batchsize = 100;
sae = saetrain(sae, train_x, opts);
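%% Reconstruction error over the whole subset (sketch)
% A small addition, not part of the original experiment: a vectorized
% version of the encode/decode loops in the Output section, applied to all
% samples at once to get a single number for reconstruction quality.
% Rows are samples here, so a column of ones (the bias input) is prepended
% and the weights are transposed. Assumes every layer uses the sigmoid set
% above. This holds a few extra 10000-by-785 matrices in memory; process
% train_x in chunks if that is too much.
h = train_x;
for i = 1:nAutoEncoders
    activ = str2func(sae.ae{i}.activation_function);
    h = activ([ones(size(h, 1), 1) h] * sae.ae{i}.W{1}');
end
% h now holds one extracted feature vector per row (10000-by-100 here).
r = h;
for i = nAutoEncoders:-1:1
    activ = str2func(sae.ae{i}.activation_function);
    r = activ([ones(size(r, 1), 1) r] * sae.ae{i}.W{2}');
end
% Mean over samples of the per-sample sum of squared errors.
reconErr = mean(sum((r - train_x) .^ 2, 2));
fprintf('Mean squared reconstruction error per sample: %g\n', reconErr);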
%% Output
% In the `test_example_SAE` demo, the learned weights are used to initialize
% a feed-forward neural network (FFNN).
% Here, however, we use the SAE directly: encode a sample, then decode it back.
fprintf('Enter a sample index (1-%d); press Enter on an empty line to quit.\n', size(train_x, 1));
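while true
    index = input('#');
    if isempty(index)
        break; % an empty input ends the loop
    end
    if ~isnumeric(index) || ~isscalar(index) || index ~= fix(index) ...
            || index < 1 || index > size(train_x, 1)
        fprintf('The index must be an integer between 1 and %d.\n', size(train_x, 1));
        continue;
    end
    x = train_x(index, :)';
    x0 = x; % keep the original input for display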
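    % The first half (encoder): pass the input through the hidden layer of
    % each autoencoder in turn. Column 1 of W{1} holds the biases, hence
    % the 1 prepended to x.
    for i = 1:nAutoEncoders
        activ = str2func(sae.ae{i}.activation_function);
        x = activ(sae.ae{i}.W{1} * [1; x]);
    end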
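    % Now `x` is the extracted feature vector (100-dimensional here).
    % The second half (decoder): apply each autoencoder's decoding weights
    % W{2} in reverse order, from the innermost autoencoder outwards.
    for i = nAutoEncoders:-1:1
        activ = str2func(sae.ae{i}.activation_function);
        x = activ(sae.ae{i}.W{2} * [1; x]);
    end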
    % Visualizes input/output.
    % If the SAE has learned a useful representation, the reconstruction
    % (right) should resemble the input (left).
    fprintf('Number = %d\n', find(train_y(index, :)) - 1);
    subplot(1, 2, 1);
    imshow(reshape(x0, [28 28])');
    subplot(1, 2, 2);
    imshow(reshape(x, [28 28])');
end
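%% First-layer filters (sketch)
% A small addition, not part of the original experiment: display the input
% weights of the first 100 (of 200) hidden units of the first autoencoder
% as 28x28 images, using the same reshape/transpose as the input display
% above. Column 1 of W{1} holds the biases, so it is dropped. Each filter is
% rescaled to [0, 1] purely for display.
W = sae.ae{1}.W{1}(:, 2:end);
figure;
for k = 1:100
    w = reshape(W(k, :), [28 28])';
    w = (w - min(w(:))) / (max(w(:)) - min(w(:)) + eps);
    subplot(10, 10, k);
    imshow(w);
end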