Created
October 19, 2023 14:54
-
-
Save benmarwick/e68faf8b7d91f9f8400bba6fc5e9fc34 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# following example from https://tensorflow.rstudio.com/tutorials/keras/classification | |
# do these steps first : https://tensorflow.rstudio.com/install/ | |
library(keras) | |
library(tidyverse) | |
library(png) | |
library(magick) | |
# Get the images into R: collect the full path of every file in the
# training folder. NOTE(review): assumes training-png/ contains only
# PNG image files -- confirm there are no stray files in that folder.
imgs <-
  list.files(path = "training-png/",
             full.names = TRUE)
# TODO: split into testing and training set
# Derive each image's class label from its file name by stripping
# hyphens, digits, and the ".png" extension (e.g. "axe-01.png" -> "axe").
# FIX: escape the dot in "\\.png" so it matches a literal period rather
# than any character.
class_names <-
  basename(imgs) %>%
  str_remove_all("-|\\d|\\.png")

# Encode the labels as 0-based integers: sparse_categorical_crossentropy
# requires labels in 0..K-1, but as.integer(factor) is 1-based.
# FIX: subtract 1L here -- the original never did, despite its own
# "subtract 1 as labels go from 0 to 9" comments further down.
class_names_int <-
  array(as.integer(as.factor(class_names)) - 1L)
# All images fed to the network must share the same dimensions.
# Crop each one to a dim1_pixels x dim2_pixels square anchored at the
# centre of the picture, then collapse it to a single grey channel.
dim1_pixels <- 30
dim2_pixels <- 30

# ImageMagick geometry string, e.g. "30x30+0+0".
crop_geometry <- paste0(dim1_pixels, "x", dim2_pixels, "+0+0")

imgs_centre_crop <-
  map(imgs, function(path) {
    picture <- magick::image_read(path)
    cropped <- magick::image_crop(picture,
                                  geometry = crop_geometry,
                                  gravity = "center",
                                  repage = TRUE)
    magick::image_convert(cropped, colorspace = "gray")
  })

# Extract the raw pixel values of each cropped image as an integer
# matrix (one matrix per image).
imgs_centre_crop_int <-
  map(imgs_centre_crop, function(picture) {
    as.integer(image_data(picture, "gray"))
  })
# Sanity check: render the first cropped image as a ggplot heatmap so
# we can eyeball that cropping/greyscaling worked.
image_1 <- as.data.frame(imgs_centre_crop_int[[1]])
colnames(image_1) <- seq_len(ncol(image_1))
image_1$y <- seq_len(nrow(image_1))
# FIX: use pivot_longer(), which supersedes the retired gather();
# result is one row per pixel with columns y, x, value.
image_1 <- pivot_longer(image_1, cols = -y,
                        names_to = "x", values_to = "value")
image_1$x <- as.integer(image_1$x)
ggplot(image_1,
       aes(x = x, y = y,
           fill = value)) +
  geom_tile() +
  scale_fill_gradient(low = "white",
                      high = "black",
                      na.value = NA) +
  # image rows count from the top, so flip the y axis
  scale_y_reverse() +
  theme_minimal() +
  theme(panel.grid = element_blank()) +
  theme(aspect.ratio = 1) +
  xlab("") +
  ylab("")
# Stack the per-image matrices into a single 3-D array, move the image
# index to the first dimension (n_images x rows x cols), and rescale
# the 0-255 pixel values into the 0-1 range the network expects.
# drop() removes the singleton channel dimension left by greyscale data.
imgs_centre_crop_int_array_perm_norm <-
  imgs_centre_crop_int %>%
  simplify2array() %>%
  drop() %>%
  aperm(c(3, 1, 2)) %>%
  `/`(255)
# Sanity check again, this time on the normalised array: plot the first
# image slice to confirm the stacking/permutation kept pictures intact.
image_1 <- as.data.frame(imgs_centre_crop_int_array_perm_norm[1, , ])
colnames(image_1) <- seq_len(ncol(image_1))
image_1$y <- seq_len(nrow(image_1))
# FIX: use pivot_longer(), which supersedes the retired gather().
image_1 <- pivot_longer(image_1, cols = -y,
                        names_to = "x", values_to = "value")
image_1$x <- as.integer(image_1$x)
ggplot(image_1, aes(x = x, y = y, fill = value)) +
  geom_tile() +
  scale_fill_gradient(low = "white", high = "black", na.value = NA) +
  # image rows count from the top, so flip the y axis
  scale_y_reverse() +
  theme_minimal() +
  theme(panel.grid = element_blank()) +
  theme(aspect.ratio = 1) +
  xlab("") +
  ylab("")
# Preview up to 25 images in a 5x5 grid with their class labels.
# Save the old graphics parameters and restore them afterwards so the
# rest of the session isn't stuck with a 5x5 layout.
op <- par(mfcol = c(5, 5), mar = c(0, 0, 1.5, 0), xaxs = 'i', yaxs = 'i')
for (i in seq_len(min(25, dim(imgs_centre_crop_int_array_perm_norm)[1]))) {
  img <- imgs_centre_crop_int_array_perm_norm[i, , ]
  # image() draws x along matrix rows; rotate so the picture is upright
  img <- t(apply(img, 2, rev))
  # BUG FIX: axis coordinates must match the pixel dimensions of the
  # matrix (dim1_pixels x dim2_pixels), not the number of image files --
  # image() errors when the coordinate vectors don't match the matrix.
  image(1:dim1_pixels, 1:dim2_pixels, img,
        col = gray((0:255) / 255),
        xaxt = 'n', yaxt = 'n',
        main = paste(class_names[i]))
}
par(op)
# Build a simple dense classifier: flatten the 2-D image, one hidden
# ReLU layer, softmax over the classes.
# FIX: derive the output size from the data instead of hard-coding 10,
# so the script works for any number of classes in training-png/.
n_classes <- length(unique(class_names))

model <- keras_model_sequential()
model %>%
  layer_flatten(input_shape = c(dim1_pixels, dim2_pixels)) %>%
  layer_dense(units = 128, activation = 'relu') %>%
  layer_dense(units = n_classes, activation = 'softmax')

# sparse_categorical_crossentropy takes integer labels in 0..K-1
# (see class_names_int above).
model %>% compile(
  optimizer = 'adam',
  loss = 'sparse_categorical_crossentropy',
  metrics = c('accuracy')
)

# NOTE(review): trains on the full image set -- the TODO about a
# train/test split still applies, so later metrics are optimistic.
model %>% fit(imgs_centre_crop_int_array_perm_norm,
              class_names_int,
              epochs = 5,
              verbose = 2)
# Evaluate on the same images used for training (no held-out set yet),
# so these numbers overstate real-world performance.
score <-
  model %>%
  evaluate(imgs_centre_crop_int_array_perm_norm,
           class_names_int,
           verbose = 0)
cat('Test loss:', score["loss"], "\n")
# FIX: accuracy was requested in compile() but never reported
cat('Test accuracy:', score["accuracy"], "\n")
predictions <- model %>% predict(imgs_centre_crop_int_array_perm_norm)

# Map softmax column indices back to class labels: column i of
# `predictions` corresponds to the i-th factor level of the labels,
# i.e. the distinct class names in sorted order.
class_levels <- sort(unique(class_names))

op <- par(mfcol = c(5, 5), mar = c(0, 0, 1.5, 0), xaxs = 'i', yaxs = 'i')
for (i in seq_len(min(25, nrow(predictions)))) {
  img <- imgs_centre_crop_int_array_perm_norm[i, , ]
  # rotate so image() draws the picture upright
  img <- t(apply(img, 2, rev))
  # BUG FIX: which.max() returns a 1-based column index, not a class
  # name. The original compared that integer with the true label string
  # (never equal) and also indexed the per-image class_names vector by
  # class index / by string, producing wrong or NA titles. Translate
  # the index through class_levels first.
  predicted_class <- class_levels[which.max(predictions[i, ])]
  true_class <- class_names[i]
  if (predicted_class == true_class) {
    color <- '#008800'
  } else {
    color <- '#bb0000'
  }
  # Title: predicted label, with the true label in parentheses.
  image(1:dim1_pixels, 1:dim2_pixels, img,
        col = gray((0:255) / 255), xaxt = 'n', yaxt = 'n',
        main = paste0(predicted_class, " (", true_class, ")"),
        col.main = color)
}
par(op)
# Grab a single image from the dataset.
# drop = FALSE keeps the batch dimension, which the model expects
# (input shape is batch x rows x cols).
img <- imgs_centre_crop_int_array_perm_norm[1, , , drop = FALSE]
dim(img)

predictions <- model %>% predict(img)
predictions

# which.max() gives the 1-based index of the winning softmax column.
prediction <- predictions[1, ]
which.max(prediction)

# BUG FIX: the winning column refers to the sorted distinct class
# levels, not to positions in the per-image class_names vector, so map
# through sort(unique(...)) to get the predicted class name.
sort(unique(class_names))[which.max(prediction)]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment