Skip to content

Instantly share code, notes, and snippets.

@ivopbernardo
Last active November 2, 2022 09:28
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ivopbernardo/f14d54e8b57147b2a751a15ed0dc842d to your computer and use it in GitHub Desktop.
Training Hyperparameters with mlr Blog Post
# mlr library example code - used in blog post:
# https://towardsdatascience.com/decision-tree-hyperparameter-tuning-in-r-using-mlr-3248bfd2d88c

# Attach dependencies first, then read the data
# (the original read the CSV before loading any library).
library(dplyr)
library(rpart)
library(rpart.plot)
library(Metrics)
library(mlr)
library(ggplot2)
library(plotly)

# Titanic training data (Kaggle "train.csv"); assumes the file is in
# the working directory -- TODO confirm path with the blog post setup
titanic <- read.csv('train.csv')
# Reproducible sampling for the train/test split below
set.seed(123)

# Keep only the columns the decision tree will use
titanic <- titanic %>%
  select(Fare, Age, Sex, Pclass, Survived, SibSp, Parch)

# Tag every row with an id so the test set can be derived as
# "everything not sampled into the training set"
titanic[["row_id"]] <- rownames(titanic)

set.seed(123)
# 80% of rows for training ...
train_data <- titanic %>%
  sample_frac(0.8)
# ... and the remaining 20% for testing
test_data <- titanic %>%
  anti_join(train_data, by = 'row_id')

# The helper id column has served its purpose; drop it
train_data$row_id <- NULL
test_data$row_id <- NULL
# Fit a baseline classification tree with rpart's default hyperparameters
d.tree <- rpart(Survived ~ ., data = train_data, method = 'class')

# Visualise the fitted tree (small text so all nodes fit)
rpart.plot(d.tree, cex = 0.55)

# Score the hold-out set: class labels, not probabilities
predicted_values <- predict(d.tree, test_data, type = 'class')

# Hold-out accuracy of the default tree
accuracy(test_data$Survived, predicted_values)
# Building our d.tree with custom parameters.
# NOTE: the original passed `control = c(maxdepth = 5, cp = 0.001)` -- a
# named numeric vector. rpart's documented interface is an rpart.control()
# object; the vector only works because rpart merges it element-wise into
# the defaults. Use rpart.control() explicitly.
d.tree.custom <- rpart(Survived ~ .,
                       data = train_data,
                       method = 'class',
                       control = rpart.control(maxdepth = 5, cp = 0.001))
rpart.plot(d.tree.custom, cex = 0.6)

# Predict test-set class labels with the customised tree
predicted_values.custom <- predict(d.tree.custom, test_data, type = 'class')

# Accuracy of the customised tree on the hold-out set
accuracy(test_data$Survived, predicted_values.custom)
# Hyperparameter tuning with mlr: list the tunable rpart parameters
getParamSet("classif.rpart")

# mlr classification tasks require a factor target. Survived is read in
# as integer 0/1 by read.csv, so makeClassifTask would reject it; convert
# it to a factor when building the task (without mutating the global
# train_data used by the plain rpart fits above).
d.tree.mlr <- makeClassifTask(
  data = train_data %>% mutate(Survived = as.factor(Survived)),
  target = "Survived"
)
# Candidate values for the only hyperparameter tuned here: tree depth
param_grid <- makeParamSet(
  makeDiscreteParam("maxdepth", values = 1:30)
)

# Exhaustive grid search over the parameter set
control_grid <- makeTuneControlGrid()

# Evaluate each candidate with 3-fold cross-validation
resample <- makeResampleDesc("CV", iters = 3L)

# Optimisation criterion: mean accuracy across folds
measure <- acc

set.seed(123)
# Run the tuning: one rpart fit per (depth, fold) combination
dt_tuneparam <- tuneParams(
  learner = 'classif.rpart',
  task = d.tree.mlr,
  resampling = resample,
  measures = measure,
  par.set = param_grid,
  control = control_grid,
  show.info = TRUE
)
# Pull the per-depth performance table out of the tuning object
result_hyperparam <- generateHyperParsEffectData(dt_tuneparam, partial.dep = TRUE)

# Mean CV accuracy as a function of maximum tree depth
ggplot(result_hyperparam$data, aes(x = maxdepth, y = acc.test.mean)) +
  geom_line(color = 'darkblue')

# Print the winning depth and its CV accuracy
dt_tuneparam
# Rebuild an rpart learner configured with the best depth found above
best_parameters <- setHyperPars(
  makeLearner("classif.rpart", predict.type = "prob"),
  par.vals = dt_tuneparam$x
)

# Train the tuned learner on the full training task
best_model <- train(best_parameters, d.tree.mlr)

# mlr tasks require a factor target; Survived is integer 0/1 from
# read.csv, so convert it here (without mutating the global test_data
# used by the plain rpart evaluations above).
d.tree.mlr.test <- makeClassifTask(
  data = test_data %>% mutate(Survived = as.factor(Survived)),
  target = "Survived"
)

# Score the tuned model on the hold-out task and compute accuracy
results <- predict(best_model, task = d.tree.mlr.test)$data
accuracy(results$truth, results$response)
# Joint search space: depth, complexity penalty and minimum split size
param_grid_multi <- makeParamSet(
  makeDiscreteParam("maxdepth", values = 1:30),
  makeNumericParam("cp", lower = 0.001, upper = 0.01),
  makeDiscreteParam("minsplit", values = 1:10)
)

# Grid search over all three hyperparameters simultaneously
# (reuses the CV scheme, measure and grid control defined earlier)
dt_tuneparam_multi <- tuneParams(
  learner = 'classif.rpart',
  task = d.tree.mlr,
  resampling = resample,
  measures = measure,
  par.set = param_grid_multi,
  control = control_grid,
  show.info = TRUE
)
# Configure a fresh learner with the best combination from the multi-search
best_parameters_multi <- setHyperPars(
  makeLearner("classif.rpart", predict.type = "prob"),
  par.vals = dt_tuneparam_multi$x
)

# Retrain on the full training task with the tuned settings
best_model_multi <- train(best_parameters_multi, d.tree.mlr)

# Evaluate the multi-tuned model on the hold-out task
results <- predict(best_model_multi, task = d.tree.mlr.test)$data
accuracy(results$truth, results$response)
# Per-combination performance table from the multi-parameter search
result_hyperparam.multi <- generateHyperParsEffectData(
  dt_tuneparam_multi,
  partial.dep = TRUE
)

# Down-sample to 300 combinations so the 3-D scatter stays readable
result_sample <- result_hyperparam.multi$data %>%
  sample_n(300)

# Interactive 3-D scatter: axes are the three hyperparameters,
# marker colour encodes mean CV accuracy (red = low, green = high)
hyperparam.plot <- plot_ly(
  result_sample,
  x = ~cp,
  y = ~maxdepth,
  z = ~minsplit,
  marker = list(
    color = ~acc.test.mean,
    colorscale = list(c(0, 1), c("darkred", "darkgreen")),
    showscale = TRUE
  )
) %>%
  add_markers()

# Render the plot
hyperparam.plot
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment