Created
March 18, 2010 02:00
-
-
Save auroranockert/335971 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
% sunspot_fftd
%
% Interactive training script for the sunspot series. Note: although the
% name refers to a focused time-delay network, the network is actually
% created with newlrn (layer-recurrent); the time-delay constructors are
% left commented out in the network-creation section.
%
% May 2010, Mattias Ohlsson
% Email: mattias@thep.lu.se

% Start from an empty workspace and no open figures.
clear all;
close all;

% Data set 1 is used for training, data set 2 for validation.
% NOTE(review): loadsuns is external; P and T are assumed to hold one
% time step per column, as con2seq below expects -- confirm.
[P, T, Yr]    = loadsuns(1);
[Pv, Tv, Yrv] = loadsuns(2);

% Standardize inputs and targets of both sets with the mean and standard
% deviation of the TRAINING targets, so both sets share one transformation.
mu    = mean(T);
sigma = std(T);
toStd = @(x) (x - mu) / sigma;
Pn  = toStd(P);
Tn  = toStd(T);
Pvn = toStd(Pv);
Tvn = toStd(Tv);

% The toolbox expects time sequences as cell arrays (one cell per step).
Pns  = con2seq(Pn);
Tns  = con2seq(Tn);
Pvns = con2seq(Pvn);
Tvns = con2seq(Tvn);
% Ask for training method.
% An empty answer, or anything other than 1, 2 or 3, falls back to
% Levenberg-Marquardt ('trainlm'). Previously an out-of-range answer left
% 'method' undefined and the script failed later when creating the network.
disp(sprintf('Choose metod:'));
disp(sprintf('Gradient descent (momentum and adaptiv lr) = 1'));
disp(sprintf('Powell-Beale conjugate gradient = 2'));
disp(sprintf('Levenberg-Marquardt = 3'));
tmp = input(' method? (default 3) = ');
trainFcns = {'traingdx', 'traincgb', 'trainlm'};   % indexed by menu number
if isnumeric(tmp) && isscalar(tmp) && any(tmp == 1:numel(trainFcns))
    method = trainFcns{tmp};
else
    method = 'trainlm';
end

% Ask for the number of hidden nodes (default 2 when empty or not > 0).
tmp = input('Number of hidden nodes? [2] ');
if tmp > 0
    nodes = tmp;
else
    nodes = 2;
end

% Ask for the size of the time delay (default 6 when empty or not > 0).
% NOTE(review): td is only referenced by the commented-out newfftd/newdtdnn
% calls; the active newlrn call does not use it.
tmp = input('Time dealy? [6] ');
if tmp > 0
    td = tmp;
else
    td = 6;
end
% Optional regularization.
% The answers are only recorded here and applied AFTER the network has been
% created: newlrn returns a brand-new network object, so any settings
% written to 'net' before that call would be discarded (which is what
% happened previously -- regularization never took effect).
tmp = input('Use regularization? (1/0) [0] ');
useReg = ~isempty(tmp) && all(tmp(:) > 0);   % same truth test as "if tmp > 0"
regRatio = 0.5;
if useReg
    tmp = input('Perfomance ratio: [0.5] ');
    if tmp > 0
        regRatio = tmp;
    end
end

% Create the network.
% NOTE(review): this builds a layer-recurrent network (newlrn); the focused
% time-delay alternatives below are commented out, so 'td' is unused here.
% net = newfftd(Pns,Tns,[0:td],nodes,{},[method]);
% net = newdtdnn(Pns, Tns, nodes, {0:td}, {}, [method]);
net = newlrn(Pns, Tns, nodes, {}, method);
net.outputs{2}.processFcns = {};     % To avoid rescaling of outputs
net.trainParam.showCommandLine = 1;  % To show the error on the commandline
net.divideFcn = '';                  % To avoid division of the data into validation and test

% Apply the regularization choice now that the network exists.
% performFcn is set first because changing it resets performParam.
% NOTE(review): msereg is obsolete in newer toolbox releases -- confirm it
% is available in the MATLAB version in use.
if useReg
    net.performFcn = 'msereg';
    net.performParam.ratio = regRatio;
end

% Number of training epochs (default 200 when empty or not > 0).
tmp = input('Number of epochs? [200] ');
if tmp > 0
    epoch = tmp;
else
    epoch = 200;
end
net.trainParam.epochs = epoch;

% Train on the full training sequence. With divideFcn empty there is no
% internal validation split, so validation error is not tracked here.
net = train(net, Pns, Tns);
%%%%%%%%%%%% Evaluate the trained network %%%%%%%%%%%%%%%
%%%%%%%%% Single Step Prediction %%%%%%
% Run the trained network on the training and validation input sequences.
tmp1 = sim(net, Pns);
tmp2 = sim(net, Pvns);
% Convert the cell-array sequences back to ordinary (non-cell) arrays.
tmp = seq2con(tmp1); Yn  = tmp{1,1};
tmp = seq2con(tmp2); Yvn = tmp{1,1};
% Undo the normalization.
Y  = Yn  * sigma + mu;
Yv = Yvn * sigma + mu;
% Number of time steps (columns) in the training / validation predictions.
M  = size(Y, 2);
Mt = size(Yv, 2);

%%%%%%%%% The dummy prediction %%%%%%
% Baseline for the validation set: every target is "predicted" by the
% previous validation target. The first step has no previous validation
% target, so it takes the first VALIDATION input (previously the first
% TRAINING input Pn(1,1) was used by mistake).
% NOTE(review): assumes Pvn(1,1) is the value immediately preceding Tv(1)
% (input lagged one step behind the target) -- confirm against loadsuns.
Yvdn = zeros(1, Mt);
Yvdn(1) = Pvn(1,1);
Yvdn(2:Mt) = Tvn(1:Mt-1);
Yvd = Yvdn * sigma + mu;

%%%%%%%%% Calculate errors %%%%%%
% NMSE = sum of squared errors / (number of steps * target variance).
etrain  = Y - T;
evalErr = Yv - Tv;    % not named 'eval': that would shadow the built-in eval
evald   = Yvd - Tv;
NMSEtrain = etrain*etrain'   / (M*std(T(1,:))^2);
NMSEval   = evalErr*evalErr' / (Mt*std(Tv(1,:))^2);
NMSEvald  = evald*evald'     / (Mt*std(Tv(1,:))^2);
%%%%%%%%% Plot single-step predictions: training (top), validation (bottom)
figres = figure;

% Top panel: training targets (blue, solid) against network output (red, dashed).
subplot(2,1,1);
plot(Yr, T, 'b-', Yr, Y, 'r--');
Err = sprintf('%7.4f', NMSEtrain);
title(['Training: Single step prediction (red = network) : NMSE = ' Err]);

% Bottom panel: the same comparison on the validation set; the title also
% reports the NMSE of the "dummy" baseline prediction.
subplot(2,1,2);
plot(Yrv, Tv, 'b-', Yrv, Yv, 'r--');
Err  = sprintf('%7.4f', NMSEval);
Errd = sprintf('%7.4f', NMSEvald);
title(['Test: Single step prediction (red = network) : NMSE = ' Err ' (dummy ' Errd ' )']);
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment