Created
March 9, 2014 03:19
-
-
Save iceboal/9442482 to your computer and use it in GitHub Desktop.
Rewrite Peng Qi's code for 1-D CRBM
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
function [model, output] = trainCRBM(data, params, oldModel)
% TRAINCRBM Trains a 1-D convolutional restricted Boltzmann machine
%   with the specified parameters.
%
%   [model, output] = TRAINCRBM(data, params, oldModel)
%
%   Inputs
%     data      1-by-N cell array of 2-D matrices. The height H is taken
%               from data{1}; the width is read separately from every cell,
%               so sequences may have different lengths.
%     params    Struct. Fields read by this function:
%                 verbose     >0 progress text, >1 per-batch dots and
%                             reconstruction error, >3 figure at every save
%                             interval, >4 figure after every batch
%                 useCuda     forwarded to whiten_data and to the PCD
%                             conve call
%                 method      optional; 'CD' (default) or 'PCD'
%                 nmap        number of filters
%                 szFilter    filter width (filters span the full height H)
%                 szPool      pooling size passed to poolHidden
%                 iter        index of the last iteration to run
%                 szBatch     batch size; floor(N / szBatch) batches are
%                             processed per iteration, the rest is skipped
%                 sparseness  >0 enables the sparsity-driven hidden bias
%                             update and is used as the target activation
%                 pW, pvbias, phbias        initial momentum terms
%                 epsW, epsvbias, epshbias  learning rates
%                 decayw      weight decay coefficient
%                 whitenData  whiten the data before training
%                 saveInterv  save every saveInterv iterations
%                 saveName    file name passed to save
%     oldModel  Optional struct to resume from. Missing fields (W, vbias,
%               hbias, sigma, iter) are initialised; training resumes at
%               oldModel.iter + 1.
%
%   Outputs
%     model     Struct with fields W (H-by-szFilter-by-nmap), vbias, hbias,
%               sigma and iter (last saved iteration).
%     output    Only filled when requested; output.x receives the pooling
%               probabilities on iterations that are saved.
%
%   Written by: Peng Qi, Sep 27, 2012
%   Last Updated: Feb 8, 2014
%   Version: 0.3 alpha

    if params.verbose > 0
        fprintf('Starting training CRBM with the following parameters:\n');
        disp(params);
        fprintf('Initializing parameters...');
    end

    useCuda = params.useCuda;

    % Select the training method; Contrastive Divergence is the default.
    method = 1;
    if isfield(params, 'method')
        if strcmp(params.method, 'CD')
            method = 1; % Contrastive Divergence
        elseif strcmp(params.method, 'PCD')
            %TODO
            method = 2; % Persistent Contrastive Divergence
        else
            % Previously `method` stayed undefined and the function failed
            % later with an unrelated "undefined variable" error.
            error('trainCRBM:unknownMethod', ...
                'params.method must be ''CD'' or ''PCD''.');
        end
    end

    %% initialization
    N = length(data);
    Nfilters = params.nmap;
    Wfilter = params.szFilter;
    p = params.szPool;
    H = size(data{1}, 1);
    W = cellfun(@(x) size(x, 2), data);  % one width per sequence
    colors = 1;
    Hhidden = H;
    Whidden = W - Wfilter + 1;           % vector, one entry per sequence
    Hpool = H;
    Wpool = floor(Whidden / p);

    param_iter = params.iter;
    param_szBatch = params.szBatch;
    output_enabled = nargout > 1;

    % Hidden biases start slightly negative when sparsity is requested.
    hinit = 0;
    if params.sparseness > 0
        hinit = -.1;
    end

    if exist('oldModel', 'var') && ~isempty(oldModel)
        model = oldModel;
    else
        model = struct();
    end

    if ~isfield(model, 'W')
        model.W = 0.01 * randn(H, Wfilter, Nfilters);
    elseif size(model.W, 1) ~= H || size(model.W, 2) ~= Wfilter || ...
            size(model.W, 3) ~= Nfilters || ndims(model.W) > 3
        % Compare dimension by dimension: an `if` on the vector
        % size(model.W) ~= [...] is true only when ALL entries differ and
        % breaks when trailing singleton dimensions are dropped.
        error('Incompatible input model.');
    end
    if ~isfield(model, 'vbias'), model.vbias = 0; end
    if ~isfield(model, 'hbias'), model.hbias = ones(1, Nfilters) * hinit; end
    if ~isfield(model, 'sigma')
        if params.sparseness > 0
            model.sigma = 0.1;
        else
            model.sigma = 1;
        end
    end

    % Momentum-smoothed gradients and their momentum coefficients.
    dW = 0;
    dvbias = 0;
    dhbias = 0;
    pW = params.pW;
    pvbias = params.pvbias;
    phbias = params.phbias;

    if output_enabled
        % NOTE(review): Wpool is a vector when N > 1, so this call does not
        % allocate a Hpool-by-Wpool-by-Nfilters-by-N array. Left unchanged
        % because the layout returned by poolHidden is not visible here.
        output.x = zeros(Hpool, Wpool, Nfilters, N);
    end

    total_batches = floor(N / param_szBatch);

    if params.verbose > 0
        fprintf('Completed.\n');
    end

    hidq = params.sparseness;  % running estimate of the hidden activation
    lambdaq = 0.9;             % smoothing factor for hidq

    if ~isfield(model, 'iter')
        model.iter = 0;
    end

    %TODO
    % NOTE(review): this branch still addresses data as a struct (data.x)
    % while the rest of the function treats it as a cell array; it will
    % fail until it is ported.
    if params.whitenData
        try
            load(sprintf('whitM_%d', params.szFilter));
        catch
            if params.verbose > 1, fprintf('\nComputing whitening matrix...'); end
            compWhitMatrix(data.x, params.szFilter);
            load(sprintf('whitM_%d', params.szFilter));
            if params.verbose > 1, fprintf('Completed.\n'); end
        end
        if params.verbose > 0, fprintf('Whitening data...'); end
        data.x = whiten_data(data.x, whM, useCuda);
        if params.verbose > 0, fprintf('Completed.\n'); end
    end

    if method == 2
        % NOTE(review): PCD is unfinished (see TODO above); W is a vector
        % here and the phantom particles are numeric, not a cell array.
        phantom = randn(H, W, colors, N);
    end

    for iter = model.iter+1:param_iter
        % shuffle data
        batch_idx = randperm(N);

        if params.verbose > 0
            fprintf('\nIteration %d\n', iter);
            if params.verbose > 1
                fprintf('Batch progress (%d total): ', total_batches);
            end
        end

        hidact = zeros(1, Nfilters);
        errsum = 0;

        if iter > 5
            % Switch momentum after the first iterations. The original
            % wrote these values to params.*, which the updates below never
            % read, so the switch had no effect.
            pW = .9;
            pvbias = 0;
            phbias = 0;
        end

        for batch = 1:total_batches
            sel = batch_idx((batch - 1) * param_szBatch + 1 : ...
                batch * param_szBatch);
            batchdata = data(sel);

            if method == 2
                phantomdata = phantom(:,:,:,((batch - 1) * param_szBatch + 1 : ...
                    batch * param_szBatch));
            end

            %% positive phase
            %% hidden update
            model_W = model.W;
            model_hbias = model.hbias;
            model_vbias = model.vbias;
            sigma = model.sigma;

            poshidacts = convs(batchdata, model_W);
            %TODO
            [poshidprobs, pospoolprobs, poshidstates] = poolHidden(...
                cellfun(@(a) a / sigma, poshidacts, 'UniformOutput', false), ...
                model_hbias / sigma, p);

            if output_enabled && ~rem(iter, params.saveInterv)
                output.x(:,:,:,sel) = pospoolprobs;
            end

            %% negative phase
            %% reconstruct data from hidden variables
            if method == 1
                recon = conve(poshidstates, model_W(:,end:-1:1,:));
            elseif method == 2
                recon = phantomdata;
            end
            recon = cellfun(@(r) r + model_vbias, recon, 'UniformOutput', false);

            if params.sparseness > 0
                % recon is a cell array, so the Gaussian noise has to be
                % added to each reconstruction separately.
                recon = cellfun(@(r) r + sigma * randn(size(r)), recon, ...
                    'UniformOutput', false);
            end
            % need add gausian here?

            %% mean field hidden update
            neghidacts = convs(recon, model_W);
            neghidprobs = poolHidden(...
                cellfun(@(a) a / sigma, neghidacts, 'UniformOutput', false), ...
                model_hbias / sigma, p);

            if params.verbose > 1
                fprintf('.');
                err = cellfun(@minus, batchdata, recon, 'UniformOutput', false);
                errsum = errsum + sum(sum(cat(2, err{:}).^2));
                if params.verbose > 4
                    %% visualize data, reconstruction, and filters (still experimental)
                    showTrainingFigure(model.W, batchdata{1}, recon{1});
                end
            end

            %% contrast divergence update on params
            if params.sparseness > 0
                % NOTE(review): assumes pospoolprobs is a numeric array
                % with filters along the 3rd dimension - confirm against
                % poolHidden.
                hidact = hidact + reshape(sum(sum(sum(pospoolprobs, 4), 2), 1), [1 Nfilters]);
            else
                % Per-sample mean of (positive - negative) hidden
                % probabilities over the first two dimensions.
                kmean = cellfun(@(pp, np) mean(mean(pp - np, 1), 2), ...
                    poshidprobs, neghidprobs, 'UniformOutput', false);
                % Average across the batch (dimension 2). The dimension is
                % explicit so a batch of one is not averaged over filters.
                dhbias = phbias * dhbias + mean(cat(2, kmean{:}), 2);
            end

            % Column means of the reconstruction residual, pooled over the
            % whole batch into one scalar visible-bias gradient.
            kmean = cellfun(@(d, r) mean(d - r, 1), batchdata, recon, ...
                'UniformOutput', false);
            dvbias = pvbias * dvbias + mean(cell2mat(kmean));

            %TODO
            ddw = convs4(batchdata, poshidprobs) ...
                - convs4(recon, neghidprobs);
            dW = pW * dW + ddw;

            model.vbias = model.vbias + params.epsvbias * dvbias;
            if params.sparseness <= 0
                model.hbias = model.hbias + reshape(params.epshbias * dhbias, [1 Nfilters]);
            end
            model.W = model.W + params.epsW * (dW - params.decayw * model.W);

            if method == 2
                % NOTE(review): written with shuffled indices but read
                % above with unshuffled ones; part of the unfinished PCD.
                phantom(:,:,:,sel) = conve(neghidprobs, model_W, useCuda);
            end
        end

        if params.verbose > 1
            fprintf('\n\terror:%f', errsum);
        end

        if params.sparseness > 0
            % Mean activation per hidden unit. Whidden holds one width per
            % sequence, so the unit count is Hhidden * sum(Whidden); this
            % equals Hhidden * Whidden * N when all widths are the same.
            hidact = hidact / (Hhidden * sum(Whidden));
            hidq = hidq * lambdaq + hidact * (1 - lambdaq);
            dhbias = phbias * dhbias + ((params.sparseness) - (hidq));
            model.hbias = model.hbias + params.epshbias * dhbias;
            if params.verbose > 0
                if params.verbose > 1
                    fprintf('\tsigma:%f', model.sigma);
                end
                fprintf('\n\tsparseness: %f\thidbias: %f\n', sum(hidact) / Nfilters, sum(model.hbias) / Nfilters);
            end
            % Anneal the noise level while it is still above 0.01.
            if model.sigma > 0.01
                model.sigma = model.sigma * 0.95;
            end
        end

        if ~rem(iter, params.saveInterv)
            % The figure needs the last batch, which only exists when at
            % least one batch was processed.
            if params.verbose > 3 && total_batches > 0
                showTrainingFigure(model.W, batchdata{1}, recon{1});
            end
            model.iter = iter;
            if output_enabled
                save(params.saveName, 'model', 'output', 'iter');
                if params.verbose > 1
                    fprintf('\nModel and output saved at iteration %d\n', iter);
                end
            else
                save(params.saveName, 'model', 'iter');
                if params.verbose > 1
                    fprintf('\nModel saved at iteration %d\n', iter);
                end
            end
        end
    end
end

function showTrainingFigure(W, dataSample, reconSample)
% SHOWTRAININGFIGURE Draws up to 16 filters plus one data/reconstruction
%   pair in figure 1.
%
%   W            H-by-Wfilter-by-Nfilters filter bank
%   dataSample   one input matrix
%   reconSample  the matching reconstruction

    figure(1);
    nShown = min(16, size(W, 3));  % do not index past the last filter
    for k = 1:nShown
        subplot(4, 8, k + 16);
        imagesc(W(:, :, k));
        axis image off;
    end
    colormap gray;
    drawnow;
    subplot(2, 2, 1); imagesc(dataSample); colormap gray; axis off; title('data');
    subplot(2, 2, 2); imagesc(reconSample); colormap gray; axis off; title('reconstruction');
    drawnow;
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment