# spam or ham: naive Bayes classifier on the `spam` dataset
# load libraries
library(kernlab)
library(naivebayes)
library(dplyr)
# load dataset into the session
data("spam")
# convert to a tibble; as_tibble() replaces the deprecated as.tbl()
spam <- as_tibble(spam)
glimpse(spam) # 58 variables, 4601 records
# distribution of the `type` variable (counts, then proportions)
table(spam$type)
table(spam$type) / nrow(spam) # imbalanced classes
# share of rows without any missing value
mean(complete.cases(spam)) # no missing value
# split dataset: 80% of the rows for training, the rest for testing
set.seed(99)
# floor() makes explicit the truncation sample() would otherwise do silently
n_train <- floor(0.8 * nrow(spam))
id <- sample(nrow(spam), n_train)
train_set <- spam[id, ]
nrow(train_set)
test_set <- spam[-id, ]
nrow(test_set)
# train model with Laplace smoothing
nb_model <- naive_bayes(type ~ ., data = train_set, laplace = 2)
# Summary function: print train and test accuracy of a fitted classifier.
#
# Arguments:
#   nb_model   fitted model object that has a predict() method
#   train_set  data frame with a `type` column holding the true labels;
#              must be the data the model was fitted on, in the same row
#              order, because training predictions are requested from the
#              model without passing newdata
#   test_set   data frame with a `type` column holding the true labels
#
# Returns (invisibly) a named numeric vector c(train = , test = ) with the
# accuracies as proportions, so callers can reuse the numbers.
nb_summary <- function(nb_model, train_set, test_set) {
  # NOTE(review): predict() without newdata presumably returns predictions
  # for the data stored in the model -- confirm for the model class in use.
  train_pred <- predict(nb_model)
  test_pred <- predict(nb_model, newdata = test_set)

  # `==` recycles silently when lengths divide; fail loudly instead of
  # reporting a meaningless accuracy.
  stopifnot(
    length(train_pred) == nrow(train_set),
    length(test_pred) == nrow(test_set)
  )

  train_acc <- mean(train_pred == train_set$type)
  test_acc <- mean(test_pred == test_set$type)

  cat("===== Model Summary =====\n")
  cat("train accuracy: ", round(train_acc * 100, 2), "%\n", sep = "")
  # trailing newline keeps the console prompt off the last output line
  cat(" test accuracy: ", round(test_acc * 100, 2), "%\n", sep = "")

  invisible(c(train = train_acc, test = test_acc))
}
# report the accuracy on the train and test data
nb_summary(nb_model, train_set, test_set)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment