Skip to content

Instantly share code, notes, and snippets.

@MCodrescu
Created June 13, 2023 20:11
Show Gist options
  • Save MCodrescu/343b3fbc4bd7cf826aac6bf86ea280f2 to your computer and use it in GitHub Desktop.
Save MCodrescu/343b3fbc4bd7cf826aac6bf86ea280f2 to your computer and use it in GitHub Desktop.
# Load image manifests ----
# Each manifest CSV is parsed with the compact column spec "cf": the first
# column as character, the second as a factor (levels in order of first
# appearance). Later code reads these columns as `filepath` and `label`.
read_manifest <- function(path) {
  readr::read_csv(path, col_types = "cf")
}
training_images <- read_manifest("train.csv")
testing_images <- read_manifest("test.csv")
# Prepare Training Data ----
# One row per training image: the `label` column followed by the image's
# 28 x 28 = 784 pixel values (assumed grayscale -- TODO confirm), flattened
# column-major by as.vector() into columns V1..V784.
n <- length(training_images$filepath)
# vapply() requires every image to flatten to exactly 784 doubles, so a
# wrong-sized PNG fails loudly instead of being recycled into the row (row
# assignment silently recycles any vector whose length divides 784). It also
# yields a correctly shaped 0 x 784 matrix for an empty manifest, where
# `for (i in 1:n)` would iterate over c(1, 0).
training_matrix <- t(vapply(
  training_images$filepath,
  function(path) as.vector(png::readPNG(path)),
  FUN.VALUE = numeric(28 * 28),
  USE.NAMES = FALSE
))
training_data <- cbind(
  dplyr::select(training_images, label),
  as.data.frame(training_matrix)
)
# Model Fitting ----
# Random-forest classifier predicting `label` from every pixel column. No
# engine is set, so parsnip's default rand_forest engine is used (ranger, per
# the parsnip docs). The wall-clock fit time is printed as a difftime.
fit_started <- Sys.time()
rf_fit <- parsnip::fit(
  object = parsnip::rand_forest(mode = "classification"),
  formula = label ~ .,
  data = training_data
)
print(Sys.time() - fit_started)
# Save model ----
# Persist the fitted model object; reload it with readRDS().
saveRDS(rf_fit, "mnist_model_fit.rds")
# Prepare Test Data ----
# Same layout as the training data: `label` followed by V1..V784. The pixel
# matrix is sized from the test manifest itself (n_test rows), so there are no
# padding rows to filter out afterwards.
n_test <- length(testing_images$filepath)
# vapply() enforces exactly 784 values per image and handles an empty manifest
# (see the training-data preparation for the same pattern).
testing_matrix <- t(vapply(
  testing_images$filepath,
  function(path) as.vector(png::readPNG(path)),
  FUN.VALUE = numeric(28 * 28),
  USE.NAMES = FALSE
))
# Rows are deliberately not filtered (no na.omit()). They must stay aligned
# 1:1 with testing_images, because the plotting code indexes
# testing_images$filepath with row numbers of the evaluation results.
testing_data <- cbind(
  dplyr::select(testing_images, label),
  as.data.frame(testing_matrix)
)
# Model Evaluation ----
# Predict a class for every test image and compare it with the true label
# using yardstick's default class metrics.
predictions <- predict(rf_fit, testing_data)
final_result <- dplyr::bind_cols(
  predictions,
  dplyr::select(testing_data, label)
)
# yardstick requires `truth` and `estimate` to carry identical factor levels
# in the same order. The predicted classes use the model's (training) levels,
# while the test labels were read with col_type "f", whose levels follow first
# appearance in test.csv. Re-level the truth to match. A test label never seen
# in training becomes NA.
final_result$label <- factor(
  as.character(final_result$label),
  levels = levels(final_result$.pred_class)
)
# Explicit print() so the table is shown under source() as well as Rscript.
print(yardstick::metrics(
  final_result,
  truth = "label",
  estimate = ".pred_class"
))
# Row indices of final_result the model classified incorrectly / correctly.
# Comparisons involving an NA label yield NA, which which() drops.
wrong_idx <- which(final_result$label != final_result$.pred_class)
right_idx <- which(final_result$label == final_result$.pred_class)
print(length(wrong_idx))
# The plot shows exactly three of each, so fail with a clear message instead
# of an opaque error (or a silently wrong draw) from sample().
if (length(right_idx) < 3 || length(wrong_idx) < 3) {
  stop(
    "Need at least 3 correct and 3 incorrect predictions to plot; got ",
    length(right_idx), " correct and ", length(wrong_idx), " incorrect.",
    call. = FALSE
  )
}
# Sample positions with sample.int() rather than calling sample() on the index
# vector, avoiding sample()'s special case for length-1 numeric input.
random_right <- right_idx[sample.int(length(right_idx), 3)]
random_wrong <- wrong_idx[sample.int(length(wrong_idx), 3)]
# Plot the mistakes ----
# Six test digits in a row: three correctly classified samples (green text)
# followed by three misclassified ones (red text). The text, nudged above each
# image, is the model's predicted class.
ggplot2::ggplot(
  data = data.frame(
    x = seq(1, 10, length.out = 6),
    y = 1,
    images = testing_images$filepath[c(random_right, random_wrong)],
    predicted = as.character(
      final_result$.pred_class[c(random_right, random_wrong)]
    )
  ),
  mapping = ggplot2::aes(x = x, y = y, image = images, label = predicted)
) +
  ggplot2::scale_x_continuous(limits = c(0, 11)) +
  ggplot2::scale_y_continuous(limits = c(0, 2)) +
  ggimage::geom_image(size = 0.1) +
  ggplot2::geom_text(
    color = rep(c("green", "red"), each = 3),
    nudge_y = 0.25,
    size = 10
  ) +
  ggplot2::theme_void()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment