@toyeiei
Last active July 26, 2023 15:37
Train a regularized logistic regression in R using the caret package
## Regularized Logistic Regression in R
## Created on 12 May 2019
## Load dataset
library(mlbench)
data("BreastCancer")
## Clean data
## head(BreastCancer)
## str(BreastCancer)
mean(complete.cases(BreastCancer)) ## proportion of rows with no missing values
df <- na.omit(BreastCancer)        ## drop rows containing NA
df$Id <- NULL                      ## Id is an identifier, not a predictor
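## Optional sanity check (a sketch): look at the class balance of the target
## before splitting, since the accuracy reported later is only meaningful
## relative to the majority-class baseline.
table(df$Class)
prop.table(table(df$Class))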
## Split data
set.seed(1)
id <- sample(1:nrow(df), 0.8*nrow(df))
train_df <- df[id, ]
test_df <- df[-id, ]
nrow(train_df); nrow(test_df)
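## Alternative split (a sketch): caret's createDataPartition() samples within
## each class, so train and test keep roughly the same benign/malignant
## proportions. The seed and 80/20 ratio mirror the random split above;
## idx, train_strat, and test_strat are illustrative names.
set.seed(1)
idx <- caret::createDataPartition(df$Class, p = 0.8, list = FALSE)
train_strat <- df[idx, ]
test_strat  <- df[-idx, ]
prop.table(table(train_strat$Class)); prop.table(table(test_strat$Class))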
## Regularization
library(caret)
library(glmnet)
## Train with 5-Fold CV
set.seed(1)
regularized_model <- train(Class ~ ., data = train_df,
                           method = "glmnet",
                           metric = "Accuracy",
                           tuneLength = 15, ## try 15 random (alpha, lambda) pairs
                           trControl = trainControl(method = "cv",
                                                    number = 5,
                                                    search = "random", ## random search
                                                    verboseIter = TRUE))
## Sample console output:
## Aggregating results
## Selecting tuning parameters
## Fitting alpha = 0.77, lambda = 0.021 on full training set
## Best Tuned Model
regularized_model$bestTune
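## Inspect the fitted coefficients (a sketch): glmnet stores the whole lambda
## path in the final model, so pass the tuned lambda via s = to see which
## dummy-coded predictors the elastic-net penalty has shrunk toward (or to) zero.
coef(regularized_model$finalModel, s = regularized_model$bestTune$lambda)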
## Train Accuracy
p3 <- predict(regularized_model, newdata = train_df, type = "prob")
p3 <- ifelse(p3[["malignant"]] >= 0.5, "malignant", "benign")
table(p3, train_df$Class)
print(sum(diag(table(p3, train_df$Class))) / nrow(train_df))
## Test Accuracy
p4 <- predict(regularized_model, newdata = test_df, type = "prob")
p4 <- ifelse(p4[["malignant"]] >= 0.5, "malignant", "benign")
table(p4, test_df$Class)
print(sum(diag(table(p4, test_df$Class))) / nrow(test_df))
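## Equivalent test-set summary (a sketch): predicting class labels directly
## (caret's default type = "raw") and passing them to confusionMatrix() gives
## accuracy plus sensitivity/specificity without the manual table arithmetic
## above; p4_class is an illustrative name.
p4_class <- predict(regularized_model, newdata = test_df)
confusionMatrix(p4_class, test_df$Class)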