
@JulienPascal
Last active January 13, 2020 15:22
CNN with Julia
# Loading packages and data
# See: https://github.com/FluxML/model-zoo/blob/master/vision/mnist/conv.jl
using Flux, Flux.Data.MNIST, Statistics
using Flux: onehotbatch, onecold, crossentropy, throttle
using Base.Iterators: repeated, partition
using Printf, BSON
using ImageView
using Plots
# Load labels and images from Flux.Data.MNIST
# Train set: images used to train (fit) the CNN
# `gpu` moves the data to the GPU when one is available; otherwise it returns the data unchanged
train_labels = gpu.(MNIST.labels(:train))
train_imgs = gpu.(MNIST.images(:train))
# Test set: held-out images used to measure how well the CNN performs out of sample
test_imgs = MNIST.images(:test)
test_labels = MNIST.labels(:test)
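# A minimal sketch (an addition, not part of the original gist): `onehotbatch`,
# imported above, is one common way to turn digit labels 0-9 into the one-hot
# targets a classifier is trained against, and `onecold` inverts the encoding.
# The labels below are made-up examples, not values read from MNIST.
using Flux: onehotbatch, onecold

example_labels = [5, 0, 4]                # hypothetical digit labels
Y = onehotbatch(example_labels, 0:9)      # 10×3 one-hot matrix, one column per label
decoded = onecold(Y, 0:9)                 # maps each column back to its label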
println("Images in the train set: $(size(train_imgs))")
println("Images in the test set: $(size(test_imgs))")
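# Visualization of one digit
# Each MNIST image is a 28×28 matrix of grayscale pixels
NROWS, NCOLS = 28, 28
a = reshape(train_imgs[1], NROWS, NCOLS)
imshow(cpu(a))  # copy the image back to the CPU (if it was moved to the GPU) and display it with ImageView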
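# A minimal sketch (an addition, not part of the original gist): Flux's `Conv`
# layers expect input as a 4-D Float32 array in WHCN order (width × height ×
# channels × batch). `make_minibatch` is a hypothetical helper, similar in spirit
# to the batching in the model-zoo example linked above, that stacks 28×28 images
# into that layout; `partition` (imported above) splits the image indices into
# batches. Random matrices stand in for the MNIST images so the snippet runs alone.
using Base.Iterators: partition

function make_minibatch(imgs, idxs)
    X = Array{Float32}(undef, 28, 28, 1, length(idxs))  # WHCN buffer with one channel
    for (k, i) in enumerate(idxs)
        X[:, :, 1, k] = Float32.(imgs[i])                # copy image i into batch slot k
    end
    return X
end

fake_imgs = [rand(Float32, 28, 28) for _ in 1:10]        # stand-ins for MNIST images
batches = [make_minibatch(fake_imgs, idxs) for idxs in partition(1:10, 4)]
# With 10 images and batch size 4, the last batch holds the 2 leftover images.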