Skip to content

Instantly share code, notes, and snippets.

@ka9e
Created January 4, 2015 09:05
Show Gist options
  • Save ka9e/456c86fac717048a9c0d to your computer and use it in GitHub Desktop.
Save ka9e/456c86fac717048a9c0d to your computer and use it in GitHub Desktop.
curve fitting
%%% references:
%%% (1) パターン認識と機械学習 (C. M. Bishop, "Pattern Recognition and Machine Learning", Japanese edition)
%%% (2) http://aidiary.hatenablog.com/entries/2014/01/22
clear all;
% --- hyper-parameters ---------------------------------------------------
LOOP = 500;   % number of training epochs
ETA  = 0.1;   % learning rate
EPS  = 0.01;  % convergence threshold (not used by the active code)
N    = 100;   % number of training samples
% --- network shape: D inputs -> M tanh hidden units -> K linear outputs --
K = 1; % out
M = 4; % hidden
D = 2; % in (constant 1 plus the scalar sample)
% --- training data: noisy samples of sin on [-5, 5] -----------------------
X = linspace(-5, 5, N);
bias = ones(1, N);
% T = sin(X);  % noise-free alternative
T = sin(X) + 0.25 * randn(1, N);
% --- random initial weights (drawn after the noise, as before) -----------
w1 = randn(M, D);   % input  -> hidden
w2 = randn(K, M);   % hidden -> output
function err = sum_sq_error(x, t, w1, w2)
% SUM_SQ_ERROR  Sum-of-squares error E = 1/2 * sum((y - t).^2) of a
% two-layer network with tanh hidden units and linear outputs.
%   x  : D-by-N inputs, one sample per column (any bias row must already
%        be part of x).
%   t  : K-by-N targets.
%   w1 : M-by-D input-to-hidden weights.
%   w2 : K-by-M hidden-to-output weights.
% Returns a scalar. The residual is summed over ALL elements, so the
% result stays scalar when K > 1 (a plain sum() would sum only down the
% columns and return a vector). The output is named `err` so the builtin
% error() is not shadowed inside the function.
z = tanh(w1 * x);   % M-by-N hidden activations
y = w2 * z;         % K-by-N network outputs
r = y - t;
err = sum(r(:) .^ 2) / 2;
end
% Design matrix for batch evaluation: row 1 is the constant (bias) input,
% row 2 is X, so column n equals the per-sample input [1; X(n)].
xs = vertcat(ones(size(X)), X);
errs = zeros(1, LOOP);   % training error after each epoch
% Online (per-sample) gradient descent with backpropagation for a fixed
% number of epochs. (A stop on |E_new - E_old| < EPS was sketched in an
% earlier version; a fixed epoch count LOOP is what actually runs.)
for iter = 1:LOOP
  for n = 1:N
    x = xs(:, n);                    % [1; X(n)]
    % forward pass
    z = tanh(w1 * x);                % hidden activations
    y = w2 * z;                      % linear output
    % backward pass: output delta, then hidden delta using
    % d tanh(a)/da = 1 - tanh(a)^2
    d2 = y - T(:, n);
    d1 = (1 - z.^2) .* (w2' * d2);
    % gradient step; d1 was computed from w2 before w2 is updated
    w1 = w1 - ETA * d1 * x';
    w2 = w2 - ETA * d2 * z';
  end
  errs(iter) = sum_sq_error(xs, T, w1, w2);
end
% Evaluate the trained network on the training inputs and plot the fit.
Z = tanh(w1 * xs);
Y = w2 * Z;
figure(1)
plot(X, T, 'o', X, Y)
% plot(errs(50:LOOP))
%%% references:
%%% (1) パターン認識と機械学習 (C. M. Bishop, "Pattern Recognition and Machine Learning", Japanese edition)
%%% (2) http://yuki-koyama.hatenablog.com/entry/2014/05/04/132552
%%% (3) http://taku-k.hatenablog.com/entry/2013/11/16/203644
% clear all;
% Least-squares and ridge-regularized regression with Gaussian radial
% basis functions: one unit-width Gaussian centred on every training input.
N = 100;   % number of training samples
RES = 200; % resolution of output curves
L = 1;     % regularization coefficient (lambda) of the ridge solution
%rbf = 'gauss';
x = linspace(-5, 5, N);
y = sin(x) + 0.25 * randn(size(x));
X = linspace(-5, 5, RES);
% Standard normal density, i.e. normpdf(r) with mu = 0 and sigma = 1,
% written out so the script does not need the statistics package
% (Octave) / Statistics Toolbox (MATLAB).
gauss = @(r) exp(-r .^ 2 / 2) / sqrt(2 * pi);
%if strcmp(rbf, 'gauss')
% N-by-N design matrix h(i, j) = gauss(|x(i) - x(j)|), built with outer
% products instead of an element-by-element double loop.
h = gauss(abs(x' * ones(1, N) - ones(N, 1) * x));
w = pinv(h) * y';                          % least-squares weights
w2 = (L * eye(N) + h' * h) \ (y * h)';     % ridge: (L*I + h'*h) w2 = h'*y'
% RES-by-N distances between every plot point X(i) and every centre x(j).
dist = abs(ones(RES, 1) * x - X' * ones(1, N));
ND = gauss(dist);
Y = ND * w;    % RES-by-1 least-squares curve
Y2 = ND * w2;  % RES-by-1 regularized curve
%end
figure(1)
subplot(2, 1, 1);
% Set MarkerSize on the data-point line afterwards: name/value pairs placed
% between data series are not accepted by MATLAB's plot().
hp = plot(x, y, 'o', X, Y, X, Y2);
set(hp(1), 'MarkerSize', 5);
axis([-5, 5, -1.5, 1.5]);
title('Gaussian')
legend('sin(x) + \epsilon', 'least-squares solution', 'after regularization', 'Location', 'NorthEastOutside');
%if strcmp(rbf, 'poly')
% Least-squares and ridge-regularized polynomial regression with basis
% functions phi_j(x) = x^(j-1), j = 1..M.
M = 12; % number of polynomial coefficients (polynomial degree M-1 = 11)
h = zeros(N, M);   % N-by-M design matrix, h(n, j) = x(n)^(j-1)
for j = 1 : M
  h(:, j) = x(:) .^ (j-1);
end
w = pinv(h) * y';                          % least-squares coefficients
w2 = (L * eye(M) + h' * h) \ (y * h)';     % ridge: (L*I + h'*h) w2 = h'*y'
% M-by-RES basis evaluated on the plot grid (preallocated at full size
% instead of growing from a 1-by-RES array inside the loop).
X2 = zeros(M, numel(X));
for i = 1 : M
  X2(i, :) = X .^ (i-1);
end
Y = w' * X2;    % 1-by-RES least-squares curve
Y2 = w2' * X2;  % 1-by-RES regularized curve
%end
subplot(2, 1, 2);
% Set MarkerSize on the data-point line afterwards: name/value pairs placed
% between data series are not accepted by MATLAB's plot().
hp = plot(x, y, 'o', X, Y, X, Y2);
set(hp(1), 'MarkerSize', 5);
axis([-5, 5, -1.5, 1.5]);
title('Polynomial')
legend('sin(x) + \epsilon', 'least-squares solution', 'after regularization', 'Location', 'NorthEastOutside');
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment