Created
December 12, 2011 16:26
-
-
Save zachmayer/1468089 to your computer and use it in GitHub Desktop.
Time series cross-validation 3
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#Function to cross-validate a time series.
#
#Repeatedly fits FUN to a training window of x (growing, or fixed-width when
#tsControl$fixedWindow is TRUE), forecasts maxHorizon steps ahead from the end
#of each window, and scores the forecasts against the observed values.
#
#Arguments:
#  x         - a ts object (checked with is.ts).
#  FUN       - forecasting function, called as FUN(xshort, h=maxHorizon, ...)
#              or, when xreg is supplied, as
#              FUN(xshort, h=maxHorizon, xreg=..., newxreg=..., ...).
#              It is expected to return maxHorizon point forecasts.
#  tsControl - list with elements stepSize, maxHorizon, minObs, fixedWindow
#              and summaryFunc. summaryFunc is called as summaryFunc(P, A)
#              with the forecasts P and the matching actuals A.
#              NOTE(review): the result is row-bound by ldply and averaged
#              with colMeans, so it presumably has to be a one-row data.frame
#              of numeric measures - confirm for custom summary functions.
#  xreg      - optional matrix or data.frame of regressors (NULL for none).
#              Row k is used as the regressors for observation k of x; rows
#              beyond length(x) are used as future values.
#  ...       - further arguments forwarded to FUN.
#
#Returns a data.frame with one row per forecast horizon plus a final 'All' row
#holding the column means across horizons.
cv.ts <- function(x, FUN, tsControl, xreg=NULL, ...) {

  #Check inputs and load required packages
  stopifnot(is.ts(x))
  stopifnot(is.data.frame(xreg) || is.matrix(xreg) || is.null(xreg))
  stopifnot(require(forecast))
  stopifnot(require(foreach))
  stopifnot(require(plyr))

  #Load parameters from the tsControl list
  stepSize <- tsControl$stepSize
  maxHorizon <- tsControl$maxHorizon
  minObs <- tsControl$minObs
  fixedWindow <- tsControl$fixedWindow
  summaryFunc <- tsControl$summaryFunc

  #Define additional parameters
  freq <- frequency(x)
  n <- length(x)
  #st is chosen so that st + i/freq is the time of observation minObs+i-1,
  #i.e. the last training observation of step i
  st <- tsp(x)[1]+(minObs-2)/freq

  #Without at least one observation after the first training window there is
  #nothing to forecast (seq() below would otherwise fail with an obscure error)
  if (n <= minObs) {
    stop('x must contain more than minObs observations')
  }

  #Make sure xreg object is long enough for last set of forecasts
  if (!is.null(xreg)) {
    xreg <- as.matrix(xreg)
    nMissing <- (n+maxHorizon)-nrow(xreg)
    if (nMissing > 0) {
      warning('xreg object too short to forecast beyond the length of the time series.
Appending NA values to xreg')
      padding <- matrix(NA, nrow=nMissing, ncol=ncol(xreg))
      colnames(padding) <- colnames(xreg)
      xreg <- rbind(xreg, padding)
    }
  }

  #Create a matrix of actual values.
  #Row t, column h holds x[t+h]; indices past the end of x yield NA
  #http://stackoverflow.com/questions/8140577/creating-a-matrix-of-future-values-for-a-time-series
  futureIndex <- outer(seq_along(x), seq_len(maxHorizon), FUN="+")
  actuals <- apply(futureIndex, 2, function(a) x[a])
  #Keep only rows whose origin is a possible end of a training window
  actuals <- actuals[minObs:(n-1),,drop=FALSE]

  #Create the forecasts, one row per training window
  #Each window has the same length if fixedWindow=TRUE
  #At each point in time, calculate 'maxHorizon' forecasts ahead
  steps <- seq(1,(n-minObs),by=stepSize)
  forecasts <- foreach(i=steps, .combine=rbind, .multicombine=FALSE) %dopar% {
    #Index of the last observation in this training window
    lastObs <- i+minObs-1
    if (fixedWindow) {
      xshort <- window(x, start=st+(i-minObs+1)/freq, end=st+i/freq)
      trainRows <- i:lastObs
    } else {
      xshort <- window(x, end=st + i/freq)
      trainRows <- seq_len(lastObs)
    }
    if (is.null(xreg)) {
      FUN(xshort, h=maxHorizon, ...)
    } else {
      futureRows <- (lastObs+1):(lastObs+maxHorizon)
      FUN(xshort, h=maxHorizon,
          xreg=xreg[trainRows,,drop=FALSE],
          newxreg=xreg[futureRows,,drop=FALSE],
          ...)
    }
  }

  #With a single step there is nothing to rbind and foreach hands back the
  #bare forecast vector; restore the one-row matrix shape used below
  if (is.null(dim(forecasts))) {
    forecasts <- matrix(forecasts, nrow=1)
  }

  #Extract the actuals we actually want to use
  actuals <- actuals[steps,,drop=FALSE]
  stopifnot(nrow(forecasts) == nrow(actuals))

  #Accuracy at each horizon
  #Forecasts and actuals are paired row by row and a pair is dropped when
  #either side is NA, so a missing value can never shift the alignment
  out <- data.frame(
    ldply(seq_len(maxHorizon),
      function(horizon) {
        P <- as.numeric(forecasts[,horizon])
        A <- as.numeric(actuals[,horizon])
        keep <- !is.na(P) & !is.na(A)
        summaryFunc(P[keep], A[keep])
      }
    )
  )

  #Add average accuracy, across all horizons
  overall <- colMeans(out)
  out <- rbind(out,overall)

  #Add a column for which horizon and output
  data.frame(horizon=c(1:maxHorizon,'All'),out)
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#Wrapper functions | |
#Summarise forecast accuracy as a one-row data.frame.
#  P - numeric vector of forecasts
#  A - numeric vector of the corresponding actual values
#Returns one row with one column per measure reported by accuracy().
tsSummary <- function(P,A) {
  acc <- accuracy(P,A)
  #accuracy() has returned a plain named vector in older forecast releases and
  #a one-row matrix in newer ones; normalise both to a one-row matrix so the
  #measures always end up as columns
  if (is.null(dim(acc))) {
    acc <- t(acc)
  }
  rownames(acc) <- NULL
  data.frame(acc)
}
#Forecast a time series with a linear model (trend + season, plus optional
#regressors) fitted by tslm().
#  x       - the training series
#  h       - number of periods to forecast
#  xreg    - optional matrix of regressors for the training period
#  newxreg - optional matrix of regressors for the forecast period; must be
#            supplied exactly when xreg is supplied
#  ...     - passed to tslm() when no regressors are used
#Returns the point forecasts (the $mean component of the forecast object).
lmForecast <- function(x,h,xreg=NULL,newxreg=NULL,...) {
  #Fail loudly if forecast is not installed
  stopifnot(require(forecast))
  x <- data.frame(x)
  colnames(x) <- 'x'
  if (is.null(xreg) && is.null(newxreg)) {
    fit <- tslm(x ~ trend + season, data=x, ...)
    forecast(fit, h=h, level=99)$mean
  } else if (!is.null(xreg) && !is.null(newxreg)) {
    newnames <- c('x',colnames(xreg))
    x <- cbind(x,xreg)
    colnames(x) <- newnames
    #Backtick-quote the regressor names so that non-syntactic column names
    #(e.g. "S1-12") are treated as variables rather than parsed as arithmetic
    regressors <- paste0("`", colnames(xreg), "`")
    fmla <- as.formula(paste("x ~ trend + season +", paste(regressors, collapse= "+")))
    #NOTE(review): '...' is not forwarded to tslm() in this branch, unlike the
    #no-regressor branch above - confirm whether that is intentional
    fit <- tslm(fmla, data=x)
    #newdata is documented as a data frame; as.data.frame keeps column names
    forecast(fit, h=h, level=99, newdata=as.data.frame(newxreg))$mean
  } else {
    stop('xreg and newxreg must both be NULL or both be provided')
  }
}
#Forecast a time series with an ARIMA model fitted by Arima().
#  x       - the training series
#  h       - number of periods to forecast
#  xreg    - optional regressors for the training period
#  newxreg - optional regressors for the forecast period
#  ...     - passed to Arima() (e.g. order, method)
#Returns the point forecasts (the $mean component of the forecast object).
arimaForecast <- function(x,h,xreg=NULL,newxreg=NULL,...) {
  model <- Arima(x, xreg=xreg, ...)
  prediction <- forecast(model, h=h, level=99, xreg=newxreg)
  prediction$mean
}
#Create xregs: Fourier terms for the monthly ldeaths series
library(forecast)
X <- fourier(ldeaths, 3)

#Cross-validate models
#Start with 12 observations, move the forecast origin one step at a time,
#forecast 12 steps ahead and let the training window grow
myControl <- list(
  minObs = 12,
  stepSize = 1,
  maxHorizon = 12,
  fixedWindow = FALSE,
  summaryFunc = tsSummary
)
lmResult <- cv.ts(ldeaths, lmForecast, myControl, xreg = X)
arimaResult <- cv.ts(
  ldeaths, arimaForecast, myControl,
  order = c(1, 0, 1), method = "ML", xreg = X
)

#Examine results
lmResult
arimaResult

#Compare MAE by horizon: linear model in black, ARIMA in red
plot(lmResult[1:12, "MAE"], type = "l")
lines(arimaResult[1:12, "MAE"], col = 2)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment