Skip to content

Instantly share code, notes, and snippets.

@rebordao
Created October 23, 2020 10:06
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rebordao/e30d0fbc98a40606613fba95d7412cfa to your computer and use it in GitHub Desktop.
Save rebordao/e30d0fbc98a40606613fba95d7412cfa to your computer and use it in GitHub Desktop.
Trains a network by backpropagation (BP), backpropagation through time (BPTT) or connectionist temporal classification (CTC)
function [net hid_without_bias output] = trains_nn(net, data, targets, phonemes_id);
% TRAINS_NN Runs one forward pass, one backward pass and one weight update.
% This function trains a network by backpropagation ('bp'), by
% backpropagation through time ('bptt') or by connectionist temporal
% classification ('ctc', whose error signal is backpropagated through time).
%
% INPUTS:
% net contains the topology and the user-defined training parameters.
%   Fields read here: training_type, architecture, W_inp, W_rec, W_out,
%   W_inp_old, W_rec_old, W_out_old, learning_rate, regularization_term,
%   mom and, for ctc, count, batch_size and verbose;
% data is the train data, one observation (time step) per column;
% targets are the real targets (used by bp and bptt only);
% phonemes_id is the indices of the phonemes in the dictionary file
%   (used by ctc only).
%
% OUTPUTS:
% net contains the topology, the updated weights and the gradient used for
%   updating the weights (net.grad); for ctc it also stores net.loglik.
% hid_without_bias is the output of the network's hidden units
% output is the net's output
%
% Antonio Rebordao, 2011
% ----------------------------------------------------------------------
nr_obser = size(data, 2);
if ~strcmp(net.training_type, 'bp') && ~strcmp(net.training_type, 'bptt') && ~strcmp(net.training_type, 'ctc')
    error('The training type needs to be bp, bptt or ctc.')
end
% plain bp never computes d_W_rec, so an rnn cannot be updated with it;
% fail early with a clear message instead of an undefined-variable error
if strcmp(net.training_type, 'bp') && strcmp(net.architecture, 'rnn')
    error('bp training does not compute the recurrent gradient; use bptt or ctc for an rnn.')
end

% loads the weights (the last column of W_out and of W_inp multiplies the bias)
W_inp = net.W_inp;
W_rec = net.W_rec;
W_out = net.W_out;

% FORWARD PASS
[hid_without_bias output] = forward_pass(net, data);

if strcmp(net.training_type, 'bp')
    %% BP TRAINING
    % BACKWARD PASS
    delta_out = output - targets;
    % 1 - h.^2 is the tanh derivative, so the hidden units are presumably
    % tanh (forward_pass is not visible here — confirm)
    delta_hid = 1 - hid_without_bias.^2;
    delta_hid = delta_hid .* (W_out(:, 1:end-1)' * delta_out);
    % computes the gradient (a row of ones appended for the bias weights)
    d_W_out = delta_out * [hid_without_bias; ones(1, nr_obser)]';
    d_W_inp = delta_hid * [data; ones(1, nr_obser)]';

elseif strcmp(net.training_type, 'bptt')
    %% BPTT TRAINING
    % BACKWARD PASS
    delta_out = output - targets;
    delta_hid = W_out(:, 1:end-1)' * delta_out;
    delta_hid(:, end) = (1 - hid_without_bias(:, end).^2) .* delta_hid(:, end);
    % walks backwards in time, adding the error flowing through W_rec
    % from step t+1 to the error coming from the output layer at step t
    for steps = nr_obser-1:-1:1
        delta_hid(:, steps) = (1 - hid_without_bias(:, steps).^2) .* (delta_hid(:, steps) + (W_rec' * delta_hid(:, steps + 1)));
    end
    % computes the gradient; W_rec connects h(t-1) to h(t)
    d_W_out = delta_out * [hid_without_bias; ones(1, nr_obser)]';
    d_W_rec = delta_hid(:, 2:nr_obser) * hid_without_bias(:, 1:end-1)';
    d_W_inp = delta_hid * [data; ones(1, nr_obser)]';

elseif strcmp(net.training_type, 'ctc')
    %% CTC training with BPTT
    % builds the label sequence with blanks around every phoneme:
    % [1 p1+1 1 p2+1 ... 1]. 1 is reserved for blanks, so every phoneme
    % id is shifted by one unit
    phonemes_id_modif = ones(2 * length(phonemes_id) + 1, 1);
    phonemes_id_modif(2:2:end) = phonemes_id + 1;

    % backward/forward algorithm
    [alpha beta loglik] = ctc_fw_bw(output, phonemes_id_modif);
    net.loglik = loglik;

    % gamma(s,t): forward-backward product for position s of
    % phonemes_id_modif at time t; its column sums give the normalizer
    % for p(l/x)
    gamma = alpha .* beta;
    sum_gamma = sum(gamma, 1);

    % computes the estimated (target) signal: output class k collects the
    % gamma of every position of phonemes_id_modif whose label equals k
    % (class 1 is the blank, class id+1 is phoneme id).
    % NOTE(review): if alpha.*beta underflows to 0 at some step, the
    % division yields NaN/Inf — depends on whether ctc_fw_bw rescales; confirm
    nr_classes = size(output, 1);
    est = zeros(nr_classes, nr_obser);
    for k = 1:nr_classes
        rows = (phonemes_id_modif == k);
        if any(rows) % classes absent from the label sequence keep a zero target
            est(k, :) = sum(gamma(rows, :), 1) ./ sum_gamma;
        end
    end

    % BACKWARD PASS and gradient computation
    delta_out = output - est;
    d_W_out = delta_out * [hid_without_bias; ones(1, nr_obser)]';
    if strcmp(net.architecture, 'reservoir')
        % only the readout is trained for a reservoir
        net.grad = d_W_out(:)';
    else
        delta_hid = W_out(:, 1:end-1)' * delta_out;
        delta_hid(:, end) = (1 - hid_without_bias(:, end).^2) .* delta_hid(:, end);
        for steps = (nr_obser-1):-1:1
            delta_hid(:, steps) = (1 - hid_without_bias(:, steps).^2) .* (delta_hid(:, steps) + (W_rec' * delta_hid(:, steps + 1)));
        end
        d_W_rec = delta_hid(:, 2:nr_obser) * hid_without_bias(:, 1:end-1)';
        d_W_inp = delta_hid * [data; ones(1, nr_obser)]';
    end

    % some plots
    if ~mod(net.count, 10 * net.batch_size) && net.verbose == 1 % displays plots every x samples if verbose = 1
        subplot(511)
        plot(alpha')
        title('alphas')
        subplot(512)
        plot(beta')
        title('betas')
        subplot(513)
        plot(output')
        axis([0 Inf 0 1])
        title('net''s output')
        subplot(514)
        plot(est')
        axis([0 Inf 0 1])
        title('prior signal used to train the network')
        subplot(515)
        plot(delta_out')
        title('net''s error')
        drawnow
    end
end

% stores the gradient of this call in net.grad
if strcmp(net.training_type, 'bp')
    % normalizes the gradients by the length of the utterance
    d_W_out = d_W_out ./ nr_obser;
    d_W_inp = d_W_inp ./ nr_obser;
    net.grad = [d_W_inp(:)' d_W_out(:)'];
elseif strcmp(net.training_type, 'bptt') || ~strcmp(net.architecture, 'reservoir')
    % bptt/ctc gradients are intentionally NOT normalized by nr_obser (that
    % normalization was disabled); the reservoir ctc case already stored
    % its readout-only gradient above
    net.grad = [d_W_inp(:)' d_W_rec(:)' d_W_out(:)'];
end

% WEIGHT UPDATE: gradient step + L2 weight decay + momentum on the previous step
net.W_out = W_out - net.learning_rate .* d_W_out - net.learning_rate * net.regularization_term * W_out + net.mom .* (W_out - net.W_out_old);
net.W_out_old = W_out;
if strcmp(net.architecture, 'ffnn') || strcmp(net.architecture, 'rnn')
    if strcmp(net.architecture, 'rnn')
        net.W_rec = W_rec - net.learning_rate .* d_W_rec - net.learning_rate * net.regularization_term * W_rec + net.mom .* (W_rec - net.W_rec_old);
        net.W_rec_old = W_rec;
    end
    net.W_inp = W_inp - net.learning_rate .* d_W_inp - net.learning_rate * net.regularization_term * W_inp + net.mom .* (W_inp - net.W_inp_old);
    net.W_inp_old = W_inp;
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment