Skip to content

Instantly share code, notes, and snippets.

@mikemahoney218
Last active August 31, 2022 07:21
Show Gist options
  • Save mikemahoney218/8a3e3a1bcd97358d231a3e760ccb1e13 to your computer and use it in GitHub Desktop.
Save mikemahoney218/8a3e3a1bcd97358d231a3e760ccb1e13 to your computer and use it in GitHub Desktop.
Proof-of-concept replication of https://geocompr.robinlovelace.net/spatial-cv.html with tidymodels
# 12.5 Spatial CV (with spatialsample)
library(tidymodels)
library(spatialsample)
library(sf)
# Load the landslide point data and the study-area mask shipped with
# spDataLarge into the session.
data("lsl", "study_mask", package = "spDataLarge")

# Terrain-attribute raster bundled with spDataLarge.
ta <- terra::rast(system.file("raster/ta.tif", package = "spDataLarge"))

# Promote the plain data frame to an sf point object, building the geometry
# from the x/y columns and tagging it with the EPSG:32717 CRS.
lsl <- st_as_sf(lsl, coords = c("x", "y"), crs = "EPSG:32717")
# 12.5.1 Generalized linear model

# Logistic regression specification fitted through the "glm" engine.
glm_model <- logistic_reg() |>
  set_engine("glm") |>
  set_mode("classification")

# Landslide occurrence modelled from five terrain predictors.
glm_formula <- lslpts ~ slope + cplan + cprof + elev + log10_carea

# Bundle model specification and formula into one workflow.
glm_wflow <- workflow() |>
  add_model(glm_model) |>
  add_formula(glm_formula)

# Seed the RNG before building the folds so the partitioning is reproducible.
set.seed(123)
lsl_folds <- spatial_clustering_cv(lsl, v = 5)

# Retain the out-of-fold predictions and the workflow in the resampling result.
keep_pred <- control_resamples(save_pred = TRUE, save_workflow = TRUE)

# Fit the workflow on every fold, then print the summarised metrics.
glm_res <- fit_resamples(
  glm_wflow,
  resamples = lsl_folds,
  control = keep_pred
)
collect_metrics(glm_res)
# 12.5.2 Spatial tuning of machine-learning hyperparameters

# RBF-kernel SVM whose cost and sigma are left as placeholders for tuning.
svm_model <- svm_rbf(cost = tune(), rbf_sigma = tune()) |>
  set_engine("kernlab", prob.model = TRUE) |>
  set_mode("classification")

svm_wflow <- workflow() |>
  add_formula(lslpts ~ slope + cplan + cprof + elev + log10_carea) |>
  add_model(svm_model)

# Base-2 log scale for the tuning ranges. The forward transform and the
# inverse must be true inverses of each other: the previous forward transform
# (-log2(x)) paired with 2^x was not, so anything applying the forward
# direction would have seen a sign-flipped scale. Sampling goes through the
# inverse, so the sampled values are unchanged by this fix.
trans_raise <- trans_new(
  "range",
  transform = \(x) log2(x),
  inverse = \(x) 2^x
)

# Random grid of 50 candidates. The ranges below are given on the
# transformed (log2) scale, i.e. cost in [2^-12, 2^15] and
# rbf_sigma in [2^-15, 2^6].
set.seed(123)
svm_grid <- svm_wflow |>
  extract_parameter_set_dials() |>
  update(
    cost = cost(c(-12, 15), trans = trans_raise),
    rbf_sigma = rbf_sigma(c(-15, 6), trans = trans_raise)
  ) |>
  grid_random(size = 50)

# Rebuild the folds with the same seed as the GLM section so this section
# does not depend on that one having been run first.
set.seed(123)
lsl_folds <- spatial_clustering_cv(lsl, v = 5)

# Evaluate every grid candidate on the folds, scoring by ROC AUC only.
set.seed(123)
svm_tune <- svm_wflow |>
  tune_grid(
    lsl_folds,
    grid = svm_grid,
    metrics = metric_set(roc_auc)
  )

# Pick the best candidate once, print it, and reuse it below.
svm_best <- select_best(svm_tune, metric = "roc_auc")
svm_best

keep_pred <- control_resamples(save_pred = TRUE, save_workflow = TRUE)

# Plug the selected hyperparameters into the workflow, refit it on the folds,
# and print the summarised metrics.
svm_wflow |>
  finalize_workflow(svm_best) |>
  fit_resamples(resamples = lsl_folds, control = keep_pred) |>
  collect_metrics()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment