Created November 15, 2018 09:09
-
-
Save b-rodrigues/abe8e0a7e21d9609a075c803bdb6b680 to your computer and use it in GitHub Desktop.
Searching for the optimal hyper-parameters of an ARIMA model in parallel: the tidy gridsearch approach
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
library(tidyverse) | |
library(forecast) | |
library(lubridate) | |
library(tsibble) | |
library(brotools) | |
# Inverse hyperbolic sine transform: log(x + sqrt(x^2 + 1)).
# Behaves like a log transform for large positive values but is also defined
# at zero and for negative values.
# Computed with base::asinh(), which is algebraically identical to the formula
# above. It avoids the cancellation in x + sqrt(x^2 + 1) for large negative x,
# where the naive formula returns -Inf (e.g. at x = -1e10).
#
# x: numeric vector or `ts`; attributes (e.g. the time-series attributes)
#    are kept on the result.
# Returns a numeric object of the same shape as `x`.
ihs <- function(x) {
  asinh(x)
}
# Flatten a forecast object into one table with a row per forecast period.
# Columns: date, point_estimate, upper80, upper95, lower80, lower95.
#
# forecast_object: object returned by forecast::forecast(). Its `mean`,
#   `upper` and `lower` components are read.
#   Assumes the forecast used the default 80% and 95% prediction levels: the
#   interval column names "80%" and "95%" are hard-coded below.
# Returns the full join (on `date`) of the three tsibble-derived pieces.
to_tibble <- function(forecast_object) {
  point_estimate <- forecast_object$mean %>%
    as_tsibble() %>%
    rename(point_estimate = value,
           date = index)

  # The interval components have one column per level. as_tsibble() stacks
  # them into key/value pairs, which are widened back to one column per level.
  upper <- forecast_object$upper %>%
    as_tsibble() %>%
    pivot_wider(names_from = key, values_from = value) %>%
    rename(date = index,
           upper80 = `80%`,
           upper95 = `95%`)

  lower <- forecast_object$lower %>%
    as_tsibble() %>%
    pivot_wider(names_from = key, values_from = value) %>%
    rename(date = index,
           lower80 = `80%`,
           lower95 = `95%`)

  # The three pieces share only `date`. Name it explicitly rather than
  # relying on an implicit natural join.
  reduce(list(point_estimate, upper, lower), full_join, by = "date")
}
avia_clean_monthly <- read_csv("https://raw.githubusercontent.com/b-rodrigues/avia_par_lu/master/avia_clean_monthy.csv")

# First calendar year assigned to the hold-out (test) period.
test_start_year <- 2015

# Sum `passengers` over all rows sharing a `date` and return the totals as a
# monthly `ts`.
# The start period is taken from the earliest date present rather than being
# hard-coded, so the series stays correctly labelled if the data or the split
# year changes.
# NOTE(review): ts() assumes a regular series; this presumes no month is
# missing between the first and last date -- confirm against the data.
monthly_total_ts <- function(df) {
  totals <- df %>%
    select(date, passengers) %>%
    group_by(date) %>%
    summarise(total_passengers = sum(passengers)) %>%
    ungroup() %>%
    arrange(date)
  if (nrow(totals) == 0) {
    stop("No rows to aggregate for this period.", call. = FALSE)
  }
  first_date <- min(totals$date)
  ts(totals$total_passengers,
     frequency = 12,
     start = c(year(first_date), month(first_date)))
}

avia_clean_train <- avia_clean_monthly %>%
  filter(year(date) < test_start_year) %>%
  monthly_total_ts()

avia_clean_test <- avia_clean_monthly %>%
  filter(year(date) >= test_start_year) %>%
  monthly_total_ts()

# Model on the IHS scale; test data is transformed identically so that
# forecasts and true values are comparable.
logged_train_data <- ihs(avia_clean_train)
logged_test_data <- ihs(avia_clean_test)
# Hyper-parameter grid: every non-seasonal order (p, d, q) crossed with every
# seasonal order (P, D, Q) at a 12-month period.
# stats::arima() documents `seasonal` as either a length-3 order or a list
# with components `order` and `period`. The list form is built here; a
# length-4 vector such as c(P, D, Q, 12) is not a valid specification.
seasonal_period <- 12

hyper_parameters_df <- expand_grid(p = 0:3, d = 0:2, q = 0:3,
                                   P = 0:3, D = 0:2, Q = 0:3) %>%
  mutate(order = pmap(list(p = p, d = d, q = q), c),
         season = pmap(list(P, D, Q), function(P, D, Q) {
           list(order = c(P, D, Q), period = seasonal_period)
         })) %>%
  select(order, season)

nrows <- nrow(hyper_parameters_df)

library(furrr)
# `multiprocess` is deprecated in the future package; `multisession` works on
# every OS. Change `workers` to the number of CPUs you want to use for training.
future::plan(future::multisession, workers = 8)

# Specifications that fail to fit yield NULL (the error is printed) instead of
# aborting the whole grid search.
safe_arima <- possibly(arima, otherwise = NULL, quiet = FALSE)

tic <- Sys.time()
models_df <- hyper_parameters_df %>%
  mutate(models = future_map2(.x = order,
                              .y = season,
                              ~ safe_arima(x = logged_train_data,
                                           order = .x,
                                           seasonal = .y)))
Sys.time() - tic
# Forecast the whole test horizon with every fitted model, then score each
# model by RMSE against the (IHS-transformed) test data.
forecast_horizon <- length(logged_test_data)

# Created outside mutate() so that `forecast` below unambiguously names the
# function, not the column of the same name.
safe_forecast <- possibly(forecast, otherwise = NULL)

# RMSE between a point forecast and the true values. A failed fit/forecast
# (NULL) scores NA so it is ignored when picking the best model.
score_rmse <- function(predicted, actual) {
  if (is.null(predicted)) {
    return(NA_real_)
  }
  sqrt(mean((predicted - actual)^2))
}

models_df <- models_df %>%
  mutate(forecast = map(models, ~ safe_forecast(.x, h = forecast_horizon)),
         point_forecast = map(forecast, ~ .x$mean),
         true_value = rep(list(logged_test_data), n()),
         rmse = map2_dbl(point_forecast, true_value, score_rmse))

if (all(is.na(models_df$rmse))) {
  stop("No ARIMA specification produced a usable forecast.", call. = FALSE)
}

# Ties (if any) all stay in `best_model`; the first one is used below.
best_model <- models_df %>%
  filter(rmse == min(rmse, na.rm = TRUE))

(best_model_forecast <- to_tibble(best_model$forecast[[1]]))
# Plot the IHS-transformed monthly passenger totals over the full sample.
# Overlay the best model's point forecast (dashed) and its 95% prediction band.
# NOTE(review): best_model_forecast$date comes from as_tsibble() on a monthly
# ts; confirm its class plays well with scale_x_date().
avia_clean_monthly %>%
  group_by(date) %>%
  summarise(total = sum(passengers)) %>%
  mutate(total_ihs = ihs(total)) %>%
  ggplot() +
  ggtitle("Logged data") +
  geom_line(aes(y = total_ihs, x = date), colour = "#82518c") +
  scale_x_date(date_breaks = "1 year", date_labels = "%m-%Y") +
  geom_ribbon(data = best_model_forecast,
              aes(x = date, ymin = lower95, ymax = upper95),
              fill = "#666018", alpha = 0.2) +
  geom_line(data = best_model_forecast,
            aes(x = date, y = point_estimate),
            linetype = 2, colour = "#8e9d98") +
  theme_blog()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.