Created
November 21, 2011 15:52
Functional and Parallel time series cross-validation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#Setup
# Fix the RNG seed, attach the data package, and build the control list that
# is handed to cv.ts() as `tsControl` by every model below.
set.seed(1)
library(fpp)  # To load the data set a10
HZ <- 12      # used as maxHorizon and as the number of horizons plotted
myControl <- list(
  minObs = 60,
  stepSize = 1,
  maxHorizon = HZ,
  fixedWindow = FALSE,
  summaryFunc = tsSummary
)
# Cross-validate each model with cv.ts(), print its accuracy table, and add its
# MAE for horizons 1..HZ to a single comparison plot. The linear model's curve
# creates the plot (axes, labels, ylim); every later model is overlaid with
# lines(). lines() only accepts graphical parameters, so axis labels are set
# once in plot() and not repeated on the overlays.

#Linear model
LM <- cv.ts(a10, lmForecast, tsControl = myControl, lambda = 0)
LM
plot(LM[1:HZ, 'MAE'], col = 1, type = 'l', ylim = c(0.65, 10),
     ylab = 'MAE', xlab = 'Horizon')
# Legend colours 1:9 match the `col` used for each model below.
legend("topleft",
       legend = c('LM', 'Arima', 'ETS', 'RW', 'Theta', 'STS', 'stlETS',
                  'stlAR', 'Mean'),
       col = 1:9, lty = 1)

#Arima
AR <- cv.ts(a10, arimaForecast, tsControl = myControl,
            order = c(3, 0, 1),
            seasonal = list(order = c(0, 1, 1), period = 12),
            include.drift = TRUE,
            lambda = 0,
            method = "ML")
AR
lines(AR[1:HZ, 'MAE'], col = 2, type = 'l')

#ETS
ETS <- cv.ts(a10, etsForecast, tsControl = myControl, model = "MMM",
             damped = TRUE)
ETS
lines(ETS[1:HZ, 'MAE'], col = 3, type = 'l')

#Random Walk model
RW <- cv.ts(a10, rwForecast, tsControl = myControl, lambda = 0)
RW
lines(RW[1:HZ, 'MAE'], col = 4, type = 'l')

#Theta model
TM <- cv.ts(a10, thetaForecast, tsControl = myControl)
TM
lines(TM[1:HZ, 'MAE'], col = 5, type = 'l')

#StructTS
sts <- cv.ts(a10, stsForecast, tsControl = myControl)
sts
lines(sts[1:HZ, 'MAE'], col = 6, type = 'l')

#stl (ets)
stlETS <- cv.ts(a10, stl.etsForecast, tsControl = myControl)
stlETS
lines(stlETS[1:HZ, 'MAE'], col = 7, type = 'l')

#stl (arima)
stlAR <- cv.ts(a10, stl.arimaForecast, tsControl = myControl)
stlAR
lines(stlAR[1:HZ, 'MAE'], col = 8, type = 'l')

#Mean model
MM <- cv.ts(a10, meanForecast, tsControl = myControl, lambda = 0)
MM
lines(MM[1:HZ, 'MAE'], col = 9, type = 'l')
#Smaller Plot
#http://robjhyndman.com/researchtips/tscvexample/
# Same MAE-by-horizon comparison on a tighter y-range, restricted to six of the
# models fitted above. Colours are kept identical to the full plot, which is
# why the legend uses c(1:3, 6:8). Axis labels are set once in plot(); lines()
# takes graphical parameters only.
plot(LM[1:HZ, 'MAE'], col = 1, type = 'l', ylim = c(0.65, 1.15),
     ylab = 'MAE', xlab = 'Horizon')
lines(AR[1:HZ, 'MAE'], col = 2, type = 'l')
lines(ETS[1:HZ, 'MAE'], col = 3, type = 'l')
lines(sts[1:HZ, 'MAE'], col = 6, type = 'l')
lines(stlETS[1:HZ, 'MAE'], col = 7, type = 'l')
lines(stlAR[1:HZ, 'MAE'], col = 8, type = 'l')
legend("topleft",
       legend = c('LM', 'Arima', 'ETS', 'STS', 'stlETS', 'stlAR'),
       col = c(1:3, 6:8), lty = 1)
#Replicate RH plot
#http://robjhyndman.com/researchtips/tscvexample/
# MAE by horizon for the linear, ARIMA and ETS fits. The x values are derived
# from HZ (not a literal 1:12) so x and y always have the same length when the
# horizon setting is changed.
plot(1:HZ, LM[1:HZ, 'MAE'], type = "l", col = 2, xlab = "horizon",
     ylab = "MAE", ylim = c(0.65, 1.05))
lines(1:HZ, AR[1:HZ, 'MAE'], type = "l", col = 3)
lines(1:HZ, ETS[1:HZ, 'MAE'], type = "l", col = 4)
legend("topleft", legend = c("LM", "ARIMA", "ETS"), col = 2:4, lty = 1)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#Define forecast methods

# Wrapper around meanf().
#   x   - series, passed as meanf()'s first argument
#   h   - passed as meanf()'s second argument
#   ... - forwarded to meanf()
# Returns only the `$mean` component of the meanf() result.
meanForecast <- function(x, h, ...) {
  # library() stops with a clear error if the package cannot be loaded,
  # instead of continuing and failing later on a missing function.
  library(forecast)
  meanf(x, h, ..., level = 99)$mean
}
# Wrapper around rwf().
#   x   - series, passed as rwf()'s first argument
#   h   - passed as rwf()'s second argument
#   ... - forwarded to rwf()
# Returns only the `$mean` component of the rwf() result.
rwForecast <- function(x, h, ...) {
  # library() stops with a clear error if the package cannot be loaded,
  # instead of continuing and failing later on a missing function.
  library(forecast)
  rwf(x, h, ..., level = 99)$mean
}
# Wrapper around thetaf().
#   x   - series, passed as thetaf()'s first argument
#   h   - passed as thetaf()'s second argument
#   ... - forwarded to thetaf()
# Returns only the `$mean` component of the thetaf() result.
thetaForecast <- function(x, h, ...) {
  # library() stops with a clear error if the package cannot be loaded,
  # instead of continuing and failing later on a missing function.
  library(forecast)
  thetaf(x, h, ..., level = 99)$mean
}
# Fits tslm() with trend and season terms and forecasts `h` steps.
#   x   - series; wrapped in a one-column data frame named 'x' before fitting
#   h   - passed to forecast() as `h`
#   ... - forwarded to tslm()
# Returns only the `$mean` component of the forecast() result.
lmForecast <- function(x, h, ...) {
  # library() stops with a clear error if the package cannot be loaded,
  # instead of continuing and failing later on a missing function.
  library(forecast)
  # The column is renamed so the formula below can refer to it as `x`.
  x <- data.frame(x)
  colnames(x) <- 'x'
  fit <- tslm(x ~ trend + season, data = x, ...)
  forecast(fit, h = h, level = 99)$mean
}
# Fits StructTS() to `x` and forecasts `h` steps.
#   x   - series passed to StructTS()
#   h   - passed to forecast() as `h`
#   ... - forwarded to StructTS()
# Returns only the `$mean` component of the forecast() result.
stsForecast <- function(x, h, ...) {
  # library() stops with a clear error if the package cannot be loaded,
  # instead of continuing and failing later on a missing function.
  library(forecast)
  fit <- StructTS(x, ...)
  forecast(fit, h = h, level = 99)$mean
}
# Decomposes `x` with stl(s.window='periodic') and forecasts the fit using
# method='ets'.
#   x   - series passed to stl()
#   h   - passed to forecast() as `h`
#   ... - forwarded to stl()
# Returns only the `$mean` component of the forecast() result.
stl.etsForecast <- function(x, h, ...) {
  # library() stops with a clear error if the package cannot be loaded,
  # instead of continuing and failing later on a missing function.
  library(forecast)
  fit <- stl(x, s.window = 'periodic', ...)
  forecast(fit, h = h, level = 99, method = 'ets')$mean
}
# Decomposes `x` with stl(s.window='periodic') and forecasts the fit using
# method='arima'.
#   x   - series passed to stl()
#   h   - passed to forecast() as `h`
#   ... - forwarded to stl()
# Returns only the `$mean` component of the forecast() result.
stl.arimaForecast <- function(x, h, ...) {
  # library() stops with a clear error if the package cannot be loaded,
  # instead of continuing and failing later on a missing function.
  library(forecast)
  fit <- stl(x, s.window = 'periodic', ...)
  forecast(fit, h = h, level = 99, method = 'arima')$mean
}
# Fits Arima() to `x` and forecasts `h` steps.
#   x   - series passed to Arima()
#   h   - passed to forecast() as `h`
#   ... - forwarded to Arima() (e.g. order, seasonal, lambda)
# Returns only the `$mean` component of the forecast() result.
arimaForecast <- function(x, h, ...) {
  # library() stops with a clear error if the package cannot be loaded,
  # instead of continuing and failing later on a missing function.
  library(forecast)
  fit <- Arima(x, ...)
  forecast(fit, h = h, level = 99)$mean
}
# Fits ets() to `x` and forecasts `h` steps.
#   x   - series passed to ets()
#   h   - passed to forecast() as `h`
#   ... - forwarded to ets() (e.g. model, damped)
# Returns only the `$mean` component of the forecast() result.
etsForecast <- function(x, h, ...) {
  # library() stops with a clear error if the package cannot be loaded,
  # instead of continuing and failing later on a missing function.
  library(forecast)
  fit <- ets(x, ...)
  forecast(fit, h = h, level = 99)$mean
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#Setup
# Fix the RNG seed, attach the data package, and build the control list that
# is handed to cv.ts() as `tsControl` by every model below.
set.seed(1)
library(fpp)  # To load the data set a10
HZ <- 12      # used as maxHorizon and as the number of horizons plotted
myControl <- list(
  minObs = 60,
  stepSize = 1,
  maxHorizon = HZ,
  fixedWindow = FALSE,
  summaryFunc = tsSummary
)
# Cross-validate three models on a10 with the shared control list; extra named
# arguments are passed through cv.ts() to the forecast wrapper.

#Linear model
LM <- cv.ts(a10, lmForecast, tsControl = myControl, lambda = 0)

#Arima
AR <- cv.ts(
  a10, arimaForecast,
  tsControl = myControl,
  order = c(3, 0, 1),
  seasonal = list(order = c(0, 1, 1), period = 12),
  include.drift = TRUE,
  lambda = 0,
  method = "ML"
)

#ETS
ETS <- cv.ts(
  a10, etsForecast,
  tsControl = myControl,
  model = "MMM",
  damped = TRUE
)

#Compare
# Auto-print each accuracy table.
LM
AR
ETS
#Plot
# MAE by horizon for the linear, ARIMA and ETS fits. The x values are derived
# from HZ (not a literal 1:12) so x and y always have the same length when the
# horizon setting is changed.
plot(1:HZ, LM[1:HZ, 'MAE'], type = "l", col = 2, xlab = "horizon",
     ylab = "MAE", ylim = c(0.65, 1.05))
lines(1:HZ, AR[1:HZ, 'MAE'], type = "l", col = 3)
lines(1:HZ, ETS[1:HZ, 'MAE'], type = "l", col = 4)
legend("topleft", legend = c("LM", "ARIMA", "ETS"), col = 2:4, lty = 1)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#Wrapper functions
# Summary function in the shape cv.ts() expects for `summaryFunc`: takes
# forecasts `P` and actuals `A`, calls accuracy(P, A), and returns the
# transposed result as a data frame.
tsSummary <- function(P, A) {
  acc <- accuracy(P, A)
  data.frame(t(acc))
}
#Define forecast methods

# Fits tslm() with trend and season terms and forecasts `h` steps.
#   x   - series; wrapped in a one-column data frame named 'x' before fitting
#   h   - passed to forecast() as `h`
#   ... - forwarded to tslm()
# Returns only the `$mean` component of the forecast() result.
lmForecast <- function(x, h, ...) {
  # library() stops with a clear error if the package cannot be loaded,
  # instead of continuing and failing later on a missing function.
  library(forecast)
  # The column is renamed so the formula below can refer to it as `x`.
  x <- data.frame(x)
  colnames(x) <- 'x'
  fit <- tslm(x ~ trend + season, data = x, ...)
  forecast(fit, h = h, level = 99)$mean
}
# Fits Arima() to `x` and forecasts `h` steps.
#   x   - series passed to Arima()
#   h   - passed to forecast() as `h`
#   ... - forwarded to Arima() (e.g. order, seasonal, lambda)
# Returns only the `$mean` component of the forecast() result.
arimaForecast <- function(x, h, ...) {
  # library() stops with a clear error if the package cannot be loaded,
  # instead of continuing and failing later on a missing function.
  library(forecast)
  fit <- Arima(x, ...)
  forecast(fit, h = h, level = 99)$mean
}
# Fits ets() to `x` and forecasts `h` steps.
#   x   - series passed to ets()
#   h   - passed to forecast() as `h`
#   ... - forwarded to ets() (e.g. model, damped)
# Returns only the `$mean` component of the forecast() result.
etsForecast <- function(x, h, ...) {
  # library() stops with a clear error if the package cannot be loaded,
  # instead of continuing and failing later on a missing function.
  library(forecast)
  fit <- ets(x, ...)
  forecast(fit, h = h, level = 99)$mean
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#Setup
# Deliberately no rm(list = ls(all = TRUE)) here: sourcing this file should
# only define cv.ts and must not delete objects (for example forecast wrapper
# functions) that already exist in the caller's workspace.
#Function to cross-validate a time series.
#
# Arguments:
#   x         - a `ts` object (enforced with is.ts()).
#   FUN       - forecast function, called as FUN(xshort, h=maxHorizon, ...).
#               Each call's result is used as one row of forecasts, one
#               column per horizon.
#   tsControl - list providing:
#                 minObs      number of observations in the first training
#                             window
#                 stepSize    spacing between successive forecast origins
#                             (treated as 1 when absent)
#                 maxHorizon  number of steps forecast from each origin
#                 fixedWindow if true, the training window keeps minObs
#                             observations and slides; otherwise it grows
#                 summaryFunc function(P, A) scoring forecasts P against
#                             actuals A for one horizon
#   ...       - forwarded unchanged to FUN.
#
# Returns a data.frame with a `horizon` column, one row of summaryFunc()
# output per horizon, and a final 'All' row holding the column means.
cv.ts <- function(x, FUN, tsControl, ...) {

  #Load required packages
  stopifnot(is.ts(x))
  library(forecast)
  library(foreach)
  library(plyr)

  #Load parameters from the tsControl list
  stepSize <- tsControl$stepSize
  maxHorizon <- tsControl$maxHorizon
  minObs <- tsControl$minObs
  fixedWindow <- tsControl$fixedWindow
  summaryFunc <- tsControl$summaryFunc

  # A control list without stepSize behaves as one forecast origin per
  # observation.
  if (is.null(stepSize)) {
    stepSize <- 1
  }

  #Define additional parameters
  freq <- frequency(x)
  n <- length(x)
  # Chosen so that st + i/freq is the time of observation (minObs - 1 + i).
  st <- tsp(x)[1] + (minObs - 2) / freq

  # At least one observation must remain after the first training window,
  # otherwise there is nothing to forecast against and the origin sequence
  # below would be invalid.
  stopifnot(n > minObs, stepSize >= 1, stepSize %% 1 == 0)

  # Forecast origins: origin i trains on data up to observation
  # (minObs - 1 + i).
  steps <- seq(1, n - minObs, by = stepSize)

  #Create a matrix of actual values, that we will later compare to forecasts.
  #X is the point in time, Y is the forecast horizon
  formatActuals <- function(x, maxHorizon) {
    idx <- outer(seq_along(x), seq_len(maxHorizon), FUN = "+")
    # Indices past the end of the series yield NA.
    apply(idx, 2, function(a) x[a])
  }
  actuals <- formatActuals(x, maxHorizon)
  # drop = FALSE keeps a matrix even when maxHorizon is 1 or only one origin
  # is selected, so the column indexing below always works.
  actuals <- actuals[minObs:(n - 1), , drop = FALSE]
  actuals <- actuals[steps, , drop = FALSE]

  #At each point in time, calculate 'maxHorizon' forecasts ahead
  #This is the 'Main Function'
  forecasts <- foreach(i = steps, .combine = rbind,
                       .multicombine = FALSE) %dopar% {
    if (fixedWindow) {
      xshort <- window(x, start = st + (i - minObs + 1) / freq,
                       end = st + i / freq)
    } else {
      xshort <- window(x, end = st + i / freq)
    }
    FUN(xshort, h = maxHorizon, ...)
  }
  # With a single origin the combine function is never applied and the result
  # is a bare vector; reshape it into a one-row matrix.
  if (!is.matrix(forecasts)) {
    forecasts <- matrix(forecasts, nrow = 1)
  }

  #Assess accuracy at each horizon
  out <- data.frame(
    ldply(seq_len(maxHorizon), function(horizon) {
      P <- forecasts[, horizon]
      A <- actuals[, horizon]
      # Score only origins whose actual value exists, keeping each forecast
      # paired with its own actual.
      observed <- !is.na(A)
      summaryFunc(P[observed], A[observed])
    })
  )

  #Calculate mean accuracy across all horizons
  overall <- colMeans(out)
  out <- rbind(out, overall)

  #Add a column for which horizon and output
  data.frame(horizon = c(seq_len(maxHorizon), 'All'), out)
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment