Last active
December 23, 2021 23:17
-
-
Save XinyueZ/eae3e9b813f6b5d55f0d85f9c386bb96 to your computer and use it in GitHub Desktop.
The numerical gradient check
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
% Banner. Because this is the first command in the file (and not a function
% definition), Octave treats the file as a script rather than a function file.
puts("The numerical gradient check......\n");
function W = randWeights(in, out)
  % Build an in-by-out matrix of uniform random values centred on zero.
  %
  % in, out : row and column counts of the result.
  % W       : rand() samples rescaled to lie between -limit and +limit,
  %           where limit = sqrt(6) / sqrt(in + out).
  limit = sqrt(6) / sqrt(in + out);
  unitSamples = rand(in, out);
  % Stretch the unit interval to width 2*limit, then shift it down by limit.
  W = unitSamples * 2 * limit - limit;
endfunction
function numericalGrads = applyNumericalGradients(costFunc, nn_params)
  % Approximate the gradient of costFunc at nn_params by central differences.
  %
  % costFunc       : function handle taking a parameter array shaped like
  %                  nn_params and returning a scalar cost.
  % nn_params      : flattened network parameters (column vector; a row
  %                  vector or any other shape is accepted as well).
  % numericalGrads : numel(nn_params)-by-1 column; entry i is
  %                  (costFunc(p + e*u_i) - costFunc(p - e*u_i)) / (2*e),
  %                  where u_i is the i-th unit vector.
  %
  % Example for three parameters [1; 2; 3], at i = 2:
  %      1           1
  %      2 - e  and  2 + e   -> the two costs whose difference gives grad #2
  %      3           3
  %
  % Only one parameter is perturbed at a time, in place, so memory use stays
  % linear in the number of parameters, and a non-finite cost for one
  % parameter cannot contaminate the other gradient entries.
  e = 1e-4;                       % perturbation step
  m = numel(nn_params);
  numericalGrads = zeros(m, 1);

  for i = 1:m
    original = nn_params(i);

    nn_params(i) = original - e;
    lossMinus = costFunc(nn_params);

    nn_params(i) = original + e;
    lossPlus = costFunc(nn_params);

    % Restore the entry so the next iteration perturbs only its own position.
    nn_params(i) = original;

    numericalGrads(i) = (lossPlus - lossMinus) ./ (2 * e);
  endfor
endfunction
function [hX, loss, X] = hypothesis(X, Y, m, n, units, nn_params)
  % Linear forward pass and its residual against Y.
  %
  % X         : inputs; a column of m ones is prepended to it below.
  % Y         : targets, subtracted from the prediction.
  % m, n      : row count used for the ones column, and feature count.
  % units     : number of output units.
  % nn_params : flattened parameters; the first units*(n+1) entries are used.
  %
  % hX   : prediction, X (with ones column) times the transposed weights.
  % loss : hX - Y.
  % X    : the input with the leading ones column attached.
  paramCount = units * (n + 1);
  theta = reshape(nn_params(1:paramCount), units, n + 1);

  % The leading column of ones multiplies the bias weight in theta(:, 1).
  X = horzcat(ones(m, 1), X);

  hX = X * theta';
  loss = hX - Y;
endfunction
function J = cost(X, Y, m, n, units, nn_params)
  % Halved mean squared error of the linear hypothesis.
  %
  % Arguments are forwarded unchanged to hypothesis(); only its residual
  % (second output) is used.
  % J : (1 / (2*m)) times the sum of all squared residual entries.
  [~, residual] = hypothesis(X, Y, m, n, units, nn_params);
  squared = residual .^ 2;
  perRow = sum(squared, 2);     % sum across columns first
  total = sum(perRow, 1);       % then down the rows
  J = (1 / (2 * m)) * total;
endfunction
function grads = gradients(X, Y, m, n, units, nn_params)
  % Analytic gradient of cost() with respect to the flattened parameters.
  %
  % Arguments are forwarded unchanged to hypothesis(), which returns the
  % residual and the input with its ones column attached.
  % grads : column vector with units*(n+1) entries, in the same order as
  %         the entries of nn_params that hypothesis() reshapes.
  [~, loss, X] = hypothesis(X, Y, m, n, units, nn_params);

  % hypothesis() reshapes nn_params to units-by-(n+1), so the gradient must
  % have that same layout before flattening. loss' * X is units-by-(n+1);
  % X' * loss would be its transpose and would flatten in the wrong order
  % whenever units > 1 (for units == 1 both give the same vector).
  grads = (1 / m) * (loss' * X);
  grads = grads(:);             % flatten to a column
endfunction
function nn()
  % Gradient-check demo: builds random data for a one-unit linear model,
  % computes the analytic gradient and the central-difference numerical
  % gradient of the cost, prints both side by side, and reports their
  % relative difference.
  m = 100;                                % number of data rows
  n = 5;                                  % number of input features
  X = randWeights(m, n);                  % fake input, m-by-n

  % Fake output, m-by-1: a repeating ramp plus a random offset.
  % `mod` is intended here. With `mode` the random number was taken as a
  % dimension argument, and a value of 2 collapsed Y to a scalar, so the
  % assertion below failed on roughly 1% of runs.
  Y = mod(1:m, randi(100, 1, 1))' + randi(100, 1, 1);
  assert(size(X, 1) == size(Y, 1));

  nn_params = randWeights(1, n + 1);      % one output unit, +1 for the bias
  nn_params = nn_params(:);               % flatten parameters to a column

  % Analytic gradient at the current parameters.
  grads = gradients(X, Y, m, n, 1, nn_params);

  % Bind the data so the cost depends on the parameters only.
  costFunc = @(params) cost(X, Y, m, n, 1, params);

  % Numerical gradient: perturb each parameter in turn.
  numericalGrads = applyNumericalGradients(costFunc, nn_params);

  disp([grads numericalGrads]);

  % Relative difference; named relDiff to avoid shadowing the builtin diff().
  relDiff = norm(numericalGrads - grads) / norm(numericalGrads + grads);

  checkSign = "𐄂";
  if (relDiff < 1e-9)
    checkSign = "✓";
  endif

  printf(['\nGradients check difference(ok if less than 1e-9).: %g (%s)\n'], relDiff, checkSign);
endfunction
% Script entry point: run the gradient-check demo.
nn();
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment