Skip to content

Instantly share code, notes, and snippets.

@simonpcouch
Created September 1, 2022 19:12
Show Gist options
Save simonpcouch/3efb7486e77c886a40e6815c4f42567c to your computer and use it in GitHub Desktop.
library(tidymodels)
library(stacks)
library(bonsai)

tidymodels_prefer()

# regression ------------------------------------------------------------------
# Meta-learner for the regression stack: a boosted tree fitted with the
# "lightgbm" engine (made available by bonsai). `mtry` is left as a tune()
# placeholder, so its value is settled while the stack is being blended.
reg_bt <-
  boost_tree(
    mtry = tune()
  ) %>%
  set_engine(engine = "lightgbm") %>%
  set_mode(mode = "regression")

# Regression stack with the boosted-tree meta-learner `reg_bt`. Everything
# (adding the candidate sets, blending, fitting the members) runs inside
# system.time() so the wall-clock cost of the whole pipeline is captured.
set.seed(1)

reg_st_bt_time <- system.time({
  reg_st_bt <- local({
    # Data stack from the three regression candidate sets; passed as bare
    # symbols so each candidate set keeps its object name inside the stack.
    reg_data_st <-
      stacks() %>%
      add_candidates(reg_res_lr) %>%
      add_candidates(reg_res_svm) %>%
      add_candidates(reg_res_sp)

    # Blend with the boosted tree, then fit the member models.
    reg_data_st %>%
      blend_predictions(meta_learner = reg_bt) %>%
      fit_members()
  })
})
#> i Creating pre-processing data to finalize unknown parameter: mtry

# Predictions on the held-out regression test set.
reg_st_bt_preds <- predict(reg_st_bt, tree_frogs_reg_test)

# Baseline regression stack: same candidates, but blend_predictions() is
# called without `meta_learner`, i.e. with its default blending model. Timed
# the same way as the boosted-tree stack so the two elapsed times compare.
set.seed(1)

reg_st_glmnet_time <- system.time({
  reg_st_glmnet <- local({
    # Identical data stack to the boosted-tree run (bare symbols keep names).
    reg_data_st <-
      stacks() %>%
      add_candidates(reg_res_lr) %>%
      add_candidates(reg_res_svm) %>%
      add_candidates(reg_res_sp)

    reg_data_st %>%
      blend_predictions() %>%
      fit_members()
  })
})

# Predictions on the held-out regression test set.
reg_st_glmnet_preds <- predict(reg_st_glmnet, tree_frogs_reg_test)

# Elapsed (wall-clock) time of `reg_st_bt` divided by that of
# `reg_st_glmnet`; a value above 1 means the boosted-tree stack was slower.
reg_st_bt_time[["elapsed"]] / reg_st_glmnet_time[["elapsed"]]
#> [1] 1.094234

# Test-set RMSE on `latency` (lower is better): boosted-tree stack, then
# the default-blend stack.
rmse_vec(
  truth = tree_frogs_reg_test[["latency"]],
  estimate = reg_st_bt_preds[[".pred"]]
)
#> [1] 60.39046
rmse_vec(
  truth = tree_frogs_reg_test[["latency"]],
  estimate = reg_st_glmnet_preds[[".pred"]]
)
#> [1] 54.13734

# classification --------------------------------------------------------------
# Meta-learner for the classification stack: a "lightgbm" boosted tree
# (engine from bonsai) with three hyperparameters left as tune()
# placeholders, to be settled while the stack is being blended.
class_bt <-
  boost_tree(
    mtry = tune(),
    min_n = tune(),
    tree_depth = tune()
  ) %>%
  set_engine(engine = "lightgbm") %>%
  set_mode(mode = "classification")

# Classification stack with the boosted-tree meta-learner `class_bt`. The
# full pipeline is wrapped in system.time() to capture its wall-clock cost.
set.seed(1)

class_st_bt_time <- system.time({
  class_st_bt <- local({
    # Data stack from the two classification candidate sets; bare symbols
    # keep each candidate set's object name inside the stack.
    class_data_st <-
      stacks() %>%
      add_candidates(class_res_rf) %>%
      add_candidates(class_res_nn)

    # Blend with the boosted tree, then fit the member models.
    class_data_st %>%
      blend_predictions(meta_learner = class_bt) %>%
      fit_members()
  })
})
#> Warning: Predictions from 1 candidate were identical to those from existing
#> candidates and were removed from the data stack.
#> i Creating pre-processing data to finalize unknown parameter: mtry

# Class predictions on the held-out classification test set.
class_st_bt_preds <- predict(class_st_bt, tree_frogs_class_test)

# Baseline classification stack: same candidates, default blending model
# (no `meta_learner` supplied). Timed identically to the boosted-tree run.
set.seed(1)

class_st_glmnet_time <- system.time({
  class_st_glmnet <- local({
    # Identical data stack to the boosted-tree run (bare symbols keep names).
    class_data_st <-
      stacks() %>%
      add_candidates(class_res_rf) %>%
      add_candidates(class_res_nn)

    class_data_st %>%
      blend_predictions() %>%
      fit_members()
  })
})
#> Warning: Predictions from 1 candidate were identical to those from existing
#> candidates and were removed from the data stack.

# Class predictions on the held-out classification test set.
class_st_glmnet_preds <- predict(class_st_glmnet, tree_frogs_class_test)

# Test-set accuracy on `reflex` (higher is better): boosted-tree stack, then
# the default-blend stack.
accuracy_vec(
  truth = tree_frogs_class_test[["reflex"]],
  estimate = class_st_bt_preds[[".pred_class"]]
)
#> [1] 0.8679868
accuracy_vec(
  truth = tree_frogs_class_test[["reflex"]],
  estimate = class_st_glmnet_preds[[".pred_class"]]
)
#> [1] 0.9009901

# Elapsed (wall-clock) time of `class_st_bt` divided by that of
# `class_st_glmnet`; a value below 1 means the boosted-tree stack was faster.
class_st_bt_time[["elapsed"]] / class_st_glmnet_time[["elapsed"]]
#> [1] 0.3143625

Created on 2022-09-01 by the reprex package (v2.0.1)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment