Skip to content

Instantly share code, notes, and snippets.

@benmarwick
Created October 19, 2023 14:54
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save benmarwick/e68faf8b7d91f9f8400bba6fc5e9fc34 to your computer and use it in GitHub Desktop.
# following example from https://tensorflow.rstudio.com/tutorials/keras/classification
# do these steps first : https://tensorflow.rstudio.com/install/
library(keras)
library(tidyverse)
library(png)
library(magick)
# Get the images into R: one PNG per image, with the class encoded in the
# file name.
imgs <-
  list.files(path = "training-png/",
             full.names = TRUE)
# TODO: split into testing and training set
# Class name = file name with hyphens, digits, and the ".png" suffix removed.
# The dot must be escaped: in a regex an unescaped "." matches any character,
# so the original ".png" would also strip e.g. "xpng" from a name.
class_names <-
  basename(imgs) %>%
  str_remove_all("-|\\d|\\.png")
# Integer class labels for keras. sparse_categorical_crossentropy expects
# 0-based labels (0 .. n_classes - 1), but as.integer() on a factor is
# 1-based, so subtract 1.
class_names_int <-
  array(as.integer(as.factor(class_names)) - 1L)
# All images must share the same dimensions for the model. Centre-crop each
# one to a dim1_pixels x dim2_pixels square and convert it to greyscale.
dim1_pixels <- 30
dim2_pixels <- 30
crop_geometry <- paste0(dim1_pixels, "x", dim2_pixels, "+0+0")
imgs_centre_crop <-
  map(imgs, function(path) {
    magick::image_read(path) %>%
      magick::image_crop(geometry = crop_geometry,
                         gravity = "center",
                         repage = TRUE) %>%
      magick::image_convert(colorspace = "gray")
  })
# Extract the raw greyscale pixel values as integers (0-255), one matrix
# per image.
imgs_centre_crop_int <-
  map(imgs_centre_crop, function(im) {
    as.integer(image_data(im, 'gray'))
  })
# Sanity check: plot the first cropped image as a greyscale heat map.
image_1 <- as.data.frame(imgs_centre_crop_int[[1]])
colnames(image_1) <- seq_len(ncol(image_1))
image_1$y <- seq_len(nrow(image_1))
# pivot_longer() supersedes gather() (tidyr >= 1.0); same long format out.
image_1 <- pivot_longer(image_1, cols = -y,
                        names_to = "x", values_to = "value")
image_1$x <- as.integer(image_1$x)
ggplot(image_1,
       aes(x = x, y = y,
           fill = value)) +
  geom_tile() +
  scale_fill_gradient(low = "white",
                      high = "black",
                      na.value = NA) +
  # images index rows from the top, so flip the y axis
  scale_y_reverse() +
  theme_minimal() +
  theme(panel.grid = element_blank()) +
  theme(aspect.ratio = 1) +
  xlab("") +
  ylab("")
# Stack the per-image matrices into a single array, put the image index on
# the first dimension (as keras expects), then rescale the 0-255 pixel
# values into the 0-1 range.
imgs_centre_crop_int_array <-
  drop(simplify2array(imgs_centre_crop_int))
# (n_images, height, width): move the image dimension to the front.
imgs_centre_crop_int_array_perm <-
  aperm(imgs_centre_crop_int_array, c(3, 1, 2))
imgs_centre_crop_int_array_perm_norm <-
  imgs_centre_crop_int_array_perm / 255
# Sanity check again: plot the first image from the normalised array.
image_1 <- as.data.frame(imgs_centre_crop_int_array_perm_norm[1, , ])
colnames(image_1) <- seq_len(ncol(image_1))
image_1$y <- seq_len(nrow(image_1))
# pivot_longer() supersedes gather() (tidyr >= 1.0); same long format out.
image_1 <- pivot_longer(image_1, cols = -y,
                        names_to = "x", values_to = "value")
image_1$x <- as.integer(image_1$x)
ggplot(image_1, aes(x = x, y = y, fill = value)) +
  geom_tile() +
  scale_fill_gradient(low = "white", high = "black", na.value = NA) +
  # images index rows from the top, so flip the y axis
  scale_y_reverse() +
  theme_minimal() +
  theme(panel.grid = element_blank()) +
  theme(aspect.ratio = 1) +
  xlab("") +
  ylab("")
# Plot up to the first 25 preprocessed images with their class labels to
# verify the crop/normalise pipeline. Save and restore par() so the grid
# settings don't leak into later plots.
old_par <- par(mfcol = c(5, 5), mar = c(0, 0, 1.5, 0), xaxs = 'i', yaxs = 'i')
for (i in seq_len(min(25, length(imgs)))) {
  img <- imgs_centre_crop_int_array_perm_norm[i, , ]
  # transpose/reverse so image() draws in reading orientation
  img <- t(apply(img, 2, rev))
  # The axis vectors must match the image dimensions in pixels. The
  # original used 1:length(imgs) (the number of files), which errors
  # whenever the image count differs from the pixel dimensions.
  image(1:dim1_pixels, 1:dim2_pixels, img,
        xaxt = 'n', yaxt = 'n',
        main = class_names[i])
}
par(old_par)
# Define a simple dense classifier: flatten the dim1_pixels x dim2_pixels
# greyscale image into a vector, one hidden relu layer, softmax over the
# classes.
# Size the output layer from the data rather than hard-coding 10 units:
# sparse_categorical_crossentropy needs one output unit per class, and the
# number of classes here comes from the file names, not from a fixed 10.
n_classes <- length(unique(class_names))
model <- keras_model_sequential()
model %>%
  layer_flatten(input_shape = c(dim1_pixels, dim2_pixels)) %>%
  layer_dense(units = 128, activation = 'relu') %>%
  layer_dense(units = n_classes, activation = 'softmax')
model %>% compile(
  optimizer = 'adam',
  loss = 'sparse_categorical_crossentropy',
  metrics = c('accuracy')
)
# Train on the full image set; labels are the integer class ids built above.
# NOTE(review): there is no held-out test set yet (see the TODO about
# splitting near the top) -- fit and evaluate below both use the same data.
model %>% fit(imgs_centre_crop_int_array_perm_norm,
class_names_int,
epochs = 5,
verbose = 2)
# Evaluate on the training data, so this is really training loss despite
# the "Test loss" label printed below.
score <-
model %>%
evaluate(imgs_centre_crop_int_array_perm_norm,
class_names_int,
verbose = 0)
cat('Test loss:', score["loss"], "\n")
# Show predictions for up to the first 25 images. Each panel is titled
# "predicted (true)", green when correct and red when not.
predictions <- model %>% predict(imgs_centre_crop_int_array_perm_norm)
# The model's output units are ordered by the factor levels of class_names,
# so map a predicted unit index back to a name via those levels -- NOT by
# indexing the per-image class_names vector (the original did the latter,
# and also compared an integer index to a string, which was never TRUE).
class_levels <- levels(as.factor(class_names))
old_par <- par(mfcol = c(5, 5), mar = c(0, 0, 1.5, 0), xaxs = 'i', yaxs = 'i')
for (i in seq_len(min(25, nrow(predictions)))) {
  img <- imgs_centre_crop_int_array_perm_norm[i, , ]
  img <- t(apply(img, 2, rev))
  # which.max gives a 1-based index into the output units / factor levels
  predicted_name <- class_levels[which.max(predictions[i, ])]
  true_name <- class_names[i]
  if (identical(predicted_name, true_name)) {
    color <- '#008800'
  } else {
    color <- '#bb0000'
  }
  image(1:dim1_pixels, 1:dim2_pixels, img, col = gray((0:255)/255),
        xaxt = 'n', yaxt = 'n',
        main = paste0(predicted_name, " (", true_name, ")"),
        col.main = color)
}
par(old_par)
# Grab an image from the test dataset.
# Take care to keep the batch dimension (drop = FALSE), as the model
# expects input of shape (batch, height, width).
img <- imgs_centre_crop_int_array_perm_norm[1, , , drop = FALSE]
dim(img)
predictions <- model %>% predict(img)
predictions
prediction <- predictions[1, ]
# which.max returns a 1-based index into the model's output units, which
# are ordered by the factor levels of class_names.
which.max(prediction)
# Map the predicted unit back to a class name via the factor levels.
# The original indexed the per-image class_names vector with this unit
# index, which names the wrong image's class, not the predicted class.
levels(as.factor(class_names))[which.max(prediction)]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment