@dfalbel
Created October 18, 2017 18:40
library(keras)
library(densenet)  # provides application_densenet()
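# MNIST digits are 28x28 grayscale images, i.e. a single channel
input_img <- layer_input(shape = c(28, 28, 1))
# DenseNet classifier built on that input, with one output per digit class
model <- application_densenet(input_tensor = input_img, classes = 10L)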
model %>% compile(
  optimizer = "adam",
  loss = "categorical_crossentropy",  # expects one-hot encoded targets
  metrics = "accuracy"
)
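mnist <- dataset_mnist()
# Scale pixel intensities from 0-255 to [0, 1]
mnist$train$x <- mnist$train$x / 255
mnist$test$x <- mnist$test$x / 255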
# Append a channel axis, (n, 28, 28) -> (n, 28, 28, 1), to match the input layer
dim(mnist$train$x) <- c(dim(mnist$train$x), 1)
dim(mnist$test$x) <- c(dim(mnist$test$x), 1)
# One-hot encode the labels: column k + 1 is 1 where the digit equals k
y <- sapply(0:9, function(x) as.numeric(x == mnist$train$y))
y_test <- sapply(0:9, function(x) as.numeric(x == mnist$test$y))
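# Train on the training set, reporting loss and accuracy on the test set
# after each epoch (epochs is left at fit()'s default)
model %>% fit(
  x = mnist$train$x,
  y = y,
  batch_size = 32,
  validation_data = list(mnist$test$x, y_test)
)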
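The script uses two preprocessing steps: it appends a channel axis with `dim<-` and one-hot encodes the labels with `sapply()`. Both can be checked in base R on a small made-up array. The toy sizes and labels below are illustrative assumptions, not MNIST data. `diag(10)[labels + 1, ]` is shown only as an equivalent alternative encoding, not as what the script uses.

```r
# Toy stand-in for mnist$train$x: 3 "images" of 2x2 pixels (illustrative only)
x_toy <- array(1:12, dim = c(3, 2, 2))

# Append a trailing channel axis, exactly as the script does
dim(x_toy) <- c(dim(x_toy), 1)
print(dim(x_toy))            # 3 2 2 1

# Setting dim only reshapes; the stored values are untouched
print(x_toy[2, 1, 2, 1])     # 8, the same element as x_toy[2, 1, 2] before

# Toy labels (illustrative only)
labels_toy <- c(3, 0, 9, 3)

# One-hot encoding as in the script: one column per digit 0..9
onehot_sapply <- sapply(0:9, function(k) as.numeric(k == labels_toy))

# Equivalent encoding by selecting rows of a 10x10 identity matrix
onehot_diag <- diag(10)[labels_toy + 1, ]

print(identical(onehot_sapply, onehot_diag))  # TRUE
# Each row holds a single 1; its column index minus 1 recovers the digit
print(max.col(onehot_sapply) - 1)             # 3 0 9 3
```

Either encoding yields a plain numeric matrix with one row per sample and ten columns, which is the target shape `categorical_crossentropy` expects.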
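The `loss = "categorical_crossentropy"` and `metrics = "accuracy"` settings passed to `compile()` can be sketched in base R for a batch of one-hot targets and predicted class probabilities. This is a simplified illustration, not Keras's implementation. Keras, for example, clips probabilities away from 0 before taking the log. The helper names and the toy matrices are hypothetical.

```r
# Hypothetical helper: mean categorical cross-entropy over a batch.
# y_true: one-hot targets (n x classes); y_pred: predicted probabilities.
categorical_crossentropy <- function(y_true, y_pred) {
  -mean(rowSums(y_true * log(y_pred)))
}

# Hypothetical helper: fraction of rows whose most probable predicted
# class matches the true class.
categorical_accuracy <- function(y_true, y_pred) {
  mean(max.col(y_pred, ties.method = "first") ==
         max.col(y_true, ties.method = "first"))
}

# Toy batch of 2 samples and 3 classes (illustrative values)
y_true <- rbind(c(1, 0, 0),
                c(0, 1, 0))
y_pred <- rbind(c(0.7, 0.2, 0.1),   # correct: class 1 is most probable
                c(0.6, 0.3, 0.1))   # wrong: predicts class 1, truth is class 2

print(categorical_crossentropy(y_true, y_pred))  # (-log(0.7) - log(0.3)) / 2
print(categorical_accuracy(y_true, y_pred))      # 0.5
```

Only the probability assigned to the true class contributes to the loss, because the one-hot targets zero out every other term. R evaluates `0 * log(0)` as `NaN`, so a real implementation needs clipping like the kind mentioned above.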