Skip to content

Instantly share code, notes, and snippets.

@iceboal
Created March 9, 2014 03:19
Show Gist options
  • Save iceboal/9442482 to your computer and use it in GitHub Desktop.
Save iceboal/9442482 to your computer and use it in GitHub Desktop.
Rewrite of Peng Qi's CRBM training code for the 1-D case
function [model, output] = trainCRBM(data, params, oldModel)
% TRAINCRBM Trains a 1-D convolutional restricted Boltzmann machine (CRBM)
% with the specified parameters.
%
% [model, output] = TRAINCRBM(data, params, oldModel)
%
% data should be a cell array with one training sequence per cell:
%   data{i}   an H-by-W_i matrix. H is taken from data{1} and is assumed to
%             be shared by every sequence; the width W_i may vary. Filters
%             are H-by-szFilter, so they span the full height and slide
%             along the width (hidden width W_i - szFilter + 1).
%
% params is a structure; the fields read here are:
%   verbose     0 = silent; >0 prints parameters and iteration numbers;
%               >1 also prints progress dots, reconstruction error and save
%               messages; >3 plots on save iterations; >4 plots every batch
%   method      (optional) 'CD' (default) or 'PCD' (experimental); any
%               other value is an error
%   nmap        number of filters
%   szFilter    filter width
%   szPool      pooling width (passed to poolHidden)
%   iter        last iteration to run (training resumes at model.iter + 1)
%   szBatch     mini-batch size; sequences left over after floor(N/szBatch)
%               shuffled batches are skipped in that iteration
%   sparseness  target activation. > 0 enables the sparsity update of the
%               hidden biases, Gaussian noise on the reconstructions and a
%               sigma that decays by 0.95 per iteration while above 0.01;
%               <= 0 uses the contrastive-divergence update of hbias
%   pW, pvbias, phbias        momentum for the first 5 iterations
%                             (0.9, 0 and 0 from iteration 6 on)
%   epsW, epsvbias, epshbias  learning rates
%   decayw      weight decay applied to W
%   whitenData  (optional) must be false: whitening has not been ported to
%               cell-array input yet
%   saveInterv  save the model every saveInterv iterations
%   saveName    file name passed to SAVE
%
% oldModel (optional) is a model to resume from. Missing fields are filled
% with fresh defaults; an existing W must be H-by-szFilter-by-nmap.
%
% model has fields W (H-by-szFilter-by-nmap), vbias (scalar), hbias
% (1-by-nmap), sigma and iter (updated whenever the model is saved).
% output (only computed when requested) has field x, a cell array the size
% of data whose cells hold the pooling probabilities returned by
% poolHidden, refreshed on iterations that are multiples of saveInterv.
%
% Relies on the external helpers convs, conve, convs4 and poolHidden.
% NOTE(review): they are assumed to take/return cell arrays with one cell
% per sequence and filters along dimension 3 - confirm against their code.
%
% Written by: Peng Qi, Sep 27, 2012
% Last Updated: Feb 8, 2014
% Version: 0.3 alpha

if params.verbose > 0,
    fprintf('Starting training CRBM with the following parameters:\n');
    disp(params);
    fprintf('Initializing parameters...');
end
method = resolveMethod(params);

%% initialization
N = length(data);
Nfilters = params.nmap;
Wfilter = params.szFilter;
p = params.szPool;
H = size(data{1}, 1);
W = cellfun(@(x) size(x, 2), data);     % width of every sequence
Hhidden = H;
Whidden = W - Wfilter + 1;              % per-sequence hidden width
param_iter = params.iter;
param_szBatch = params.szBatch;
output_enabled = nargout > 1;

hinit = 0;
if params.sparseness > 0,
    hinit = -.1;
end
if exist('oldModel', 'var') && ~isempty(oldModel),
    model = oldModel;
else
    model = struct();
end
model = fillModelDefaults(model, H, Wfilter, Nfilters, hinit, params.sparseness);

% Momentum accumulators and their coefficients.
dW = 0;
dvbias = 0;
dhbias = 0;
pW = params.pW;
pvbias = params.pvbias;
phbias = params.phbias;
if output_enabled,
    % One cell per sequence: pooled maps have a different width per sequence.
    output.x = cell(size(data));
end
total_batches = floor(N / param_szBatch);
if params.verbose > 0,
    fprintf('Completed.\n');
end
hidq = params.sparseness;   % running estimate of the hidden activation
lambdaq = 0.9;              % decay of that running estimate

if isfield(params, 'whitenData') && params.whitenData,
    % TODO: the whitening step (compWhitMatrix / whiten_data on data.x) was
    % written for the 4-D struct input of the 2-D CRBM and has not been
    % ported to cell-array input.
    error('trainCRBM:whitenNotSupported', ...
        'params.whitenData is not supported for cell-array input yet.');
end

if method == 2,
    % PCD (experimental): one persistent fantasy particle per sequence, with
    % the same size, so it lines up with the data it is compared against.
    phantom = cellfun(@(x) randn(size(x)), data, 'UniformOutput', 0);
end

for iter = model.iter+1:param_iter,
    % shuffle data
    batch_idx = randperm(N);
    if params.verbose > 0,
        fprintf('\nIteration %d\n', iter);
        if params.verbose > 1,
            fprintf('Batch progress (%d total): ', total_batches);
        end
    end
    hidact = zeros(1, Nfilters);
    errsum = 0;
    if iter > 5,
        % Switch momentum after the first few iterations. These must be the
        % locals used by the updates below (assigning params.* had no effect).
        pW = .9;
        pvbias = 0;
        phbias = 0;
    end

    for batch = 1:total_batches,
        idx = batch_idx((batch - 1) * param_szBatch + 1 : batch * param_szBatch);
        batchdata = data(idx);

        model_W = model.W;
        model_hbias = model.hbias;
        model_vbias = model.vbias;
        sigma = model.sigma;

        %% positive phase: hidden/pooling probabilities given the data
        poshidacts = convs(batchdata, model_W);
        [poshidprobs, pospoolprobs, poshidstates] = poolHidden( ...
            cellfun(@(a) a / sigma, poshidacts, 'UniformOutput', 0), ...
            model_hbias / sigma, p);
        if output_enabled && ~rem(iter, params.saveInterv),
            output.x(idx) = pospoolprobs;
        end

        %% negative phase: reconstruct the visible layer
        if method == 1,
            recon = conve(poshidstates, model_W(:, end:-1:1, :));
        else
            recon = phantom(idx);
        end
        recon = cellfun(@(r) r + model_vbias, recon, 'UniformOutput', 0);
        if params.sparseness > 0,
            % Gaussian visible units: sample around the mean reconstruction.
            recon = cellfun(@(r) r + sigma * randn(size(r)), recon, 'UniformOutput', 0);
        end

        %% mean field hidden update
        neghidacts = convs(recon, model_W);
        neghidprobs = poolHidden( ...
            cellfun(@(a) a / sigma, neghidacts, 'UniformOutput', 0), ...
            model_hbias / sigma, p);

        if params.verbose > 1,
            fprintf('.');
            err = cellfun(@minus, batchdata, recon, 'UniformOutput', 0);
            errsum = errsum + sum(sum(cat(2, err{:}).^2));
            if params.verbose > 4,
                showProgress(model.W, batchdata, recon);
            end
        end

        %% contrastive divergence update on params
        if params.sparseness > 0,
            hidact = hidact + sumPerFilter(pospoolprobs, Nfilters);
        else
            % Per-sequence mean over positions, then mean over the batch
            % (explicit dims so a batch of one still averages over examples).
            hdiff = cellfun(@(a, b) reshape(mean(mean(a - b, 1), 2), [1 Nfilters]), ...
                poshidprobs, neghidprobs, 'UniformOutput', 0);
            dhbias = phbias * dhbias + mean(cat(1, hdiff{:}), 1);
        end
        % Mean residual over every visible unit in the batch.
        vdiff = cellfun(@(a, b) sum(a(:) - b(:)), batchdata, recon);
        vcount = sum(cellfun(@numel, batchdata));
        dvbias = pvbias * dvbias + sum(vdiff) / vcount;

        ddw = convs4(batchdata, poshidprobs) - convs4(recon, neghidprobs);
        dW = pW * dW + ddw;
        model.vbias = model.vbias + params.epsvbias * dvbias;
        if params.sparseness <= 0,
            model.hbias = model.hbias + reshape(params.epshbias * dhbias, [1 Nfilters]);
        end
        model.W = model.W + params.epsW * (dW - params.decayw * model.W);

        if method == 2,
            % Advance the persistent chains one mean-field step, using the
            % same reconstruction operator as the CD path.
            phantom(idx) = conve(neghidprobs, model_W(:, end:-1:1, :));
        end
    end

    if params.verbose > 1,
        fprintf('\n\terror:%f', errsum);
    end
    if params.sparseness > 0,
        % Average activation per filter. Whidden varies per sequence, so
        % normalise by the total number of hidden positions.
        % NOTE(review): hidact sums pooled probabilities while the
        % normaliser counts hidden positions (kept from the original) - confirm.
        hidact = hidact / (Hhidden * sum(Whidden));
        hidq = hidq * lambdaq + hidact * (1 - lambdaq);
        dhbias = phbias * dhbias + ((params.sparseness) - (hidq));
        model.hbias = model.hbias + params.epshbias * dhbias;
        if params.verbose > 0,
            if params.verbose > 1,
                fprintf('\tsigma:%f', model.sigma);
            end
            fprintf('\n\tsparseness: %f\thidbias: %f\n', sum(hidact) / Nfilters, sum(model.hbias) / Nfilters);
        end
        if model.sigma > 0.01,
            model.sigma = model.sigma * 0.95;
        end
    end

    if ~rem(iter, params.saveInterv),
        % batchdata/recon only exist once at least one batch has run.
        if params.verbose > 3 && total_batches > 0,
            showProgress(model.W, batchdata, recon);
        end
        model.iter = iter;
        if output_enabled,
            save(params.saveName, 'model', 'output', 'iter');
            if params.verbose > 1,
                fprintf('\nModel and output saved at iteration %d\n', iter);
            end
        else
            save(params.saveName, 'model', 'iter');
            if params.verbose > 1,
                fprintf('\nModel saved at iteration %d\n', iter);
            end
        end
    end
end
end

function method = resolveMethod(params)
% RESOLVEMETHOD Maps params.method to 1 (CD, the default) or 2 (PCD).
method = 1; % Contrastive Divergence is the default
if isfield(params, 'method'),
    if strcmp(params.method, 'CD'),
        method = 1;
    elseif strcmp(params.method, 'PCD'),
        method = 2; % Persistent Contrastive Divergence (experimental)
    else
        error('trainCRBM:unknownMethod', ...
            'Unknown params.method; expected ''CD'' or ''PCD''.');
    end
end
end

function model = fillModelDefaults(model, H, Wfilter, Nfilters, hinit, sparseness)
% FILLMODELDEFAULTS Adds any missing CRBM fields and validates the size of W.
if ~isfield(model, 'W'),
    model.W = 0.01 * randn(H, Wfilter, Nfilters);
else
    % Compare three explicit sizes: SIZE drops a trailing singleton
    % dimension (Nfilters == 1) and a vector IF condition needs ALL true.
    szW = [size(model.W, 1), size(model.W, 2), size(model.W, 3)];
    if ~isequal(szW, [H Wfilter Nfilters]),
        error('Incompatible input model.');
    end
end
if ~isfield(model, 'vbias'), model.vbias = 0; end
if ~isfield(model, 'hbias'), model.hbias = ones(1, Nfilters) * hinit; end
if ~isfield(model, 'sigma'),
    if sparseness > 0,
        model.sigma = 0.1;
    else
        model.sigma = 1;
    end
end
if ~isfield(model, 'iter'), model.iter = 0; end
end

function s = sumPerFilter(x, Nfilters)
% SUMPERFILTER Sums activations over all positions and examples into a
% 1-by-Nfilters row. Accepts a cell array of per-example arrays or a single
% 4-D array; assumes filters are indexed along dimension 3 in both cases.
if iscell(x),
    s = zeros(1, Nfilters);
    for k = 1:numel(x),
        s = s + reshape(sum(sum(x{k}, 1), 2), [1 Nfilters]);
    end
else
    s = reshape(sum(sum(sum(x, 4), 2), 1), [1 Nfilters]);
end
end

function showProgress(W, batchdata, recon)
% SHOWPROGRESS Experimental visualisation: plots up to 16 filters, the first
% sequence of the batch and its reconstruction in figure 1.
figure(1);
for i = 1:min(16, size(W, 3)),
    subplot(4, 8, i + 16);
    imagesc(W(:, :, i));
    axis image off;
end
colormap gray;
drawnow;
subplot(2, 2, 1); imagesc(batchdata{1}); colormap gray; axis off; title('data');
subplot(2, 2, 2); imagesc(recon{1}); colormap gray; axis off; title('reconstruction');
drawnow;
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment