Skip to content

Instantly share code, notes, and snippets.

@b-rodrigues
Created November 15, 2018 09:09
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save b-rodrigues/abe8e0a7e21d9609a075c803bdb6b680 to your computer and use it in GitHub Desktop.
Save b-rodrigues/abe8e0a7e21d9609a075c803bdb6b680 to your computer and use it in GitHub Desktop.
Searching for the optimal hyper-parameters of an ARIMA model in parallel: the tidy gridsearch approach
library(tidyverse)
library(forecast)
library(lubridate)
library(tsibble)
library(brotools)
#' Inverse hyperbolic sine transform.
#'
#' Computes `log(x + sqrt(x^2 + 1))` element-wise. Unlike a plain logarithm,
#' the expression is defined at zero.
#'
#' @param x A numeric vector (or any object supporting arithmetic, `sqrt()`
#'   and `log()`).
#' @return An object of the same shape as `x` holding the transformed values.
ihs <- function(x) {
  squared_plus_one <- x^2 + 1
  log(x + sqrt(squared_plus_one))
}
#' Flatten a forecast object into one table of estimates and bounds.
#'
#' Reads the `mean`, `upper` and `lower` components of `forecast_object`,
#' converts each to a tsibble and joins them on the date column.
#'
#' @param forecast_object An object with `mean`, `upper` and `lower`
#'   components; `upper` and `lower` must expose `80%` and `95%` series.
#' @return A table with columns `date`, `point_estimate`, `upper80`,
#'   `upper95`, `lower80` and `lower95`.
to_tibble <- function(forecast_object) {
  point_estimate <- forecast_object$mean %>%
    as_tsibble() %>%
    rename(point_estimate = value, date = index)

  # The bounds arrive in long form (one row per date and level); spread them
  # so that each confidence level becomes its own column.
  upper <- forecast_object$upper %>%
    as_tsibble() %>%
    spread(key, value) %>%
    rename(date = index, upper80 = `80%`, upper95 = `95%`)

  lower <- forecast_object$lower %>%
    as_tsibble() %>%
    spread(key, value) %>%
    rename(date = index, lower80 = `80%`, lower95 = `95%`)

  # Join on an explicit key: `date` is the only column the three tables share,
  # and naming it avoids the "Joining, by" messages and accidental key changes.
  reduce(list(point_estimate, upper, lower), full_join, by = "date")
}
# Download the monthly passenger data set.
avia_clean_monthly <- read_csv(
  "https://raw.githubusercontent.com/b-rodrigues/avia_par_lu/master/avia_clean_monthy.csv"
)

# Training series: total passengers per date for every year before 2015,
# stored as a monthly `ts` declared to start in January 2005.
avia_clean_train <- avia_clean_monthly %>%
  filter(year(date) < 2015) %>%
  select(date, passengers) %>%
  group_by(date) %>%
  summarise(total_passengers = sum(passengers)) %>%
  pull(total_passengers) %>%
  ts(start = c(2005, 1), frequency = 12)

# Test series: same aggregation for 2015 onwards, declared to start in
# January 2015.
avia_clean_test <- avia_clean_monthly %>%
  filter(year(date) >= 2015) %>%
  select(date, passengers) %>%
  group_by(date) %>%
  summarise(total_passengers = sum(passengers)) %>%
  pull(total_passengers) %>%
  ts(start = c(2015, 1), frequency = 12)

# Apply the inverse hyperbolic sine transform to both series.
logged_train_data <- ihs(avia_clean_train)
logged_test_data <- ihs(avia_clean_test)
# Build the hyper-parameter grid for the search.
#
# `expand.grid()` enumerates every combination (first factor varying fastest)
# and `pmap(c)` turns each row into a named vector. This replaces purrr's
# `cross()` and `lift()`, which are deprecated.

# Non-seasonal part: every (p, d, q) combination.
order_list <- expand.grid(
  p = seq(0, 3),
  d = seq(0, 2),
  q = seq(0, 3)
) %>%
  pmap(c)

# Seasonal part: every (P, D, Q) combination with a fixed period of 12.
# NOTE(review): each vector carries `period` as a fourth element; confirm that
# the model-fitting call downstream accepts the seasonal spec in this form.
season_list <- expand.grid(
  P = seq(0, 3),
  D = seq(0, 2),
  Q = seq(0, 3),
  period = 12
) %>%
  pmap(c)

# One list-column per part, then all pairings of the two.
orderdf <- tibble("order" = order_list)
seasondf <- tibble("season" = season_list)
hyper_parameters_df <- crossing(orderdf, seasondf)

# Number of candidate models in the grid.
nrows <- nrow(hyper_parameters_df)
library(furrr)
# Fit one ARIMA model per row of the hyper-parameter grid, in parallel.
# `multisession` is used because the `multiprocess` strategy is deprecated in
# the future package.
plan(multisession, workers = 8) # Change to whatever number of CPUs you want to use for training
tic <- Sys.time()
models_df <- hyper_parameters_df %>%
  mutate(
    models = future_map2(
      .x = order,
      .y = season,
      # `possibly()` makes a failed fit return NULL instead of aborting the
      # whole search; `quiet = FALSE` still prints the error.
      ~ possibly(arima, otherwise = NULL, quiet = FALSE)(
        # Train on the transformed training series defined earlier in the
        # script (the previous name, `logged_data`, was never defined).
        x = logged_train_data,
        order = .x,
        seasonal = .y
      )
    )
  )
# Elapsed wall-clock time for the grid search.
Sys.time() - tic
# Forecast with every fitted model and score it against the test series.
models_df <- models_df %>%
  mutate(
    # Namespace-qualify `forecast`: once the `forecast` column exists, a bare
    # `forecast` inside mutate() would resolve to the column, not the function,
    # which breaks the block when it is re-run.
    # NOTE(review): h = 39 presumably matches the number of observations in
    # the test series; confirm.
    forecast = map(
      models,
      ~ possibly(forecast::forecast, otherwise = NULL)(., h = 39)
    ),
    point_forecast = map(forecast, ~ .x$mean),
    # A length-one list is recycled to every row, so no row count is needed
    # (replaces the deprecated `rerun()`).
    true_value = list(logged_test_data),
    # Root mean squared error between forecast and observed values.
    rmse = map2_dbl(
      point_forecast,
      true_value,
      ~ sqrt(mean((.x - .y)^2))
    )
  )

# Keep the row(s) with the smallest RMSE, ignoring missing scores.
best_model <- models_df %>%
  filter(rmse == min(rmse, na.rm = TRUE))

# Tabulate (and print) the forecast of the first best-scoring model.
(best_model_forecast <- to_tibble(best_model$forecast[[1]]))
# Plot the transformed monthly totals together with the best model's forecast
# (dashed line) and its 95% interval (shaded ribbon).
avia_clean_monthly %>%
  group_by(date) %>%
  summarise(total = sum(passengers)) %>%
  mutate(total_ihs = ihs(total)) %>%
  ggplot() +
  ggtitle("Logged data") +
  geom_line(aes(y = total_ihs, x = date), colour = "#82518c") +
  scale_x_date(date_breaks = "1 year", date_labels = "%m-%Y") +
  # Use `best_model_forecast`, the table built from the best model; the
  # previously referenced `best_model_forecast_furrr` is never defined.
  geom_ribbon(
    data = best_model_forecast,
    aes(x = date, ymin = lower95, ymax = upper95),
    fill = "#666018",
    alpha = 0.2
  ) +
  geom_line(
    data = best_model_forecast,
    aes(x = date, y = point_estimate),
    linetype = 2,
    colour = "#8e9d98"
  ) +
  theme_blog()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment