Navigation Menu

Skip to content

Instantly share code, notes, and snippets.

@ToucheSir
Created September 17, 2021 16:45
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ToucheSir/072097e4d50147cdd97dac00c6622317 to your computer and use it in GitHub Desktop.
Save ToucheSir/072097e4d50147cdd97dac00c6622317 to your computer and use it in GitHub Desktop.
# Classifies MNIST digits with a convolutional network.
# Saves the best model's parameters (with its epoch index and test accuracy) to "mnist_conv.bson".
# Demonstrates basic model construction, training, saving,
# conditional early-exit, and learning rate scheduling.
#
# This model, while simple, should hit around 99% test
# accuracy after training for approximately 20 epochs.
using Flux, MLDatasets, Random, Statistics
using Flux: unsqueeze, onehotbatch, onecold, logitcrossentropy
using Base.Iterators: partition
using Printf, BSON
using Parameters: @with_kw
using CUDA
if has_cuda()
    # Forbid scalar (element-by-element) indexing of GPU arrays so an
    # accidental scalar loop raises an error instead of running very slowly.
    CUDA.allowscalar(false)
    @info "CUDA is on"
end
# Hyperparameters and paths for `train`/`test`, constructible with keywords
# (via `Parameters.@with_kw`), e.g. `Args(; epochs = 5)`.
#   lr         — initial ADAM learning rate
#   epochs     — maximum number of training epochs
#   batch_size — samples per training minibatch
#   savepath   — directory where "mnist_conv.bson" is written / read
@with_kw mutable struct Args
    lr::Float64 = 3e-3
    epochs::Int = 20
    # Concretely typed: an untyped field is `Any`, making every read of it
    # type-unstable; `Iterators.partition` requires an integer size anyway.
    batch_size::Int = 128
    savepath::String = "./"
end
# Bundle images together with labels and group into minibatches
function make_minibatch(X, Y, idxs)
X_batch = unsqueeze(X[:, :, idxs], 3)
Y_batch = onehotbatch(Y[idxs], 0:9)
return X_batch, Y_batch
end
"""
    get_processed_data(args)

Load MNIST (as `Float32`) through MLDatasets and return `(train_set, test_set)`.

`train_set` is a vector of `(images, onehot_labels)` minibatches, cut in
dataset order into consecutive chunks of at most `args.batch_size` samples.
`test_set` is a single `(images, onehot_labels)` tuple covering the whole
test split.
"""
function get_processed_data(args)
    # Training split, chopped into consecutive index ranges.
    train_x, train_y = MNIST.traindata(Float32)
    batch_ranges = partition(1:length(train_y), args.batch_size)
    train_set = map(r -> make_minibatch(train_x, train_y, r), batch_ranges)

    # Test split, kept as one giant minibatch for evaluation.
    test_x, test_y = MNIST.testdata(Float32)
    test_set = make_minibatch(test_x, test_y, 1:length(test_y))

    return train_set, test_set
end
"""
    build_model(args; imgsize = (28, 28, 1), nclasses = 10)

Construct the classifier: three stages of `Conv(3×3, pad 1, relu) -> MaxPool(2×2)`
with channels `imgsize[3] => 16 => 32 => 32`, then `flatten` and a `Dense`
layer producing `nclasses` raw logits (no softmax; `train` pairs it with
`logitcrossentropy`).

`args` is accepted but not used. The `Dense` input width assumes each pooling
floor-halves the spatial size, i.e. `floor(Int, imgsize[i] / 8)` per spatial
dimension times 32 channels.
"""
function build_model(args; imgsize = (28,28,1), nclasses = 10)
    # One stage: a size-preserving 3x3 convolution followed by 2x2 max-pooling.
    stage(ch) = (Conv((3, 3), ch, pad=(1,1), relu), MaxPool((2,2)))
    flat_width = floor(Int, imgsize[1] / 8) * floor(Int, imgsize[2] / 8) * 32
    return Chain(
        stage(imgsize[3] => 16)...,  # operates on the 28x28 image
        stage(16 => 32)...,          # operates on 14x14
        stage(32 => 32)...,          # operates on 7x7, pooled down to 3x3
        flatten,                     # (3, 3, 32, N) -> (288, N)
        Dense(flat_width, nclasses),
    )
end
"""
    augment(x)

Return a noisy copy of `x`: every element gets independent standard-normal
noise (drawn with `randn!`) scaled by `0.1f0` added to it. `x` itself is not
modified.
"""
function augment(x)
    # `similar` gives a buffer of the same array and element type as `x`,
    # which `randn!` then fills in place.
    noise = randn!(similar(x))
    return x .+ 0.1f0 * noise
end
"""
    anynan(m)

Return `true` if any array collected by `params(m)` contains a `NaN`.
Stops at the first array that does.

NOTE(review): intended to be given a model; confirm `params` recollects
arrays when `m` is already a `Params` collection before passing one.
"""
anynan(m) = any(p -> any(isnan, p), params(m))
"""
    accuracy(x, y, model)

Return the fraction of samples whose predicted class (`onecold` of
`model(x)`) equals the `onecold` of the one-hot targets `y`. Both are moved to
the CPU before decoding.
"""
function accuracy(x, y, model)
    predicted = onecold(cpu(model(x)))
    expected = onecold(cpu(y))
    return mean(predicted .== expected)
end
"""
    train(; kws...)

Train the convolutional MNIST classifier with hyperparameters `Args(; kws...)`.

Each epoch runs one `Flux.train!` pass with ADAM over the training
minibatches (inputs perturbed by `augment`), then logs test-set accuracy.
Whenever the accuracy matches or beats the best seen so far, the CPU copies of
the parameters, the epoch index and the accuracy are saved to
`joinpath(args.savepath, "mnist_conv.bson")`.

Training stops early in three cases: on NaN parameters, at ≥ 99.9% test
accuracy, or once 10 epochs pass without an improvement or learning-rate
drop. After 5 such epochs the learning rate is divided by 10, as long as it
is still above 1e-6.
"""
function train(; kws...)
    args = Args(; kws...)
    @info("Loading data set")
    train_set, test_set = get_processed_data(args)

    # Define our model. We will use a simple convolutional architecture with
    # three iterations of Conv -> ReLU -> MaxPool, followed by a final Dense layer.
    @info("Building model...")
    # Build and move to the GPU (if enabled) in a single assignment. `model` is
    # captured by the `loss` closure below. A captured variable that is assigned
    # more than once gets boxed, which makes every access to it inside the
    # closure type-unstable.
    model = gpu(build_model(args))
    # Load datasets onto GPU, if enabled
    train_set = gpu.(train_set)
    test_set = gpu.(test_set)
    ps = params(model)

    # Make sure our model is nicely precompiled before starting our training loop
    model(train_set[1][1])

    # `loss()` calculates the crossentropy loss between our prediction `y_hat`
    # (calculated from `model(x)`) and the ground truth `y`. We augment the data
    # a bit, adding gaussian random noise to our image to make it more robust.
    function loss(x, y)
        x̂ = augment(x)
        ŷ = model(x̂)
        return logitcrossentropy(ŷ, y)
    end

    # Train our model with the given training set using the ADAM optimizer and
    # printing out performance against the test set as we go.
    opt = ADAM(args.lr)
    @info "Beginning training loop..."
    best_acc = 0.0
    last_improvement = 0
    for epoch_idx in 1:args.epochs
        # Train for a single epoch
        Flux.train!(loss, ps, train_set, opt)

        # Terminate on NaN. `anynan` calls `params` on its argument, so pass the
        # model itself. Passing the `Params` object `ps` would rely on
        # `params(::Params)` re-collecting the arrays it holds, which it may not
        # do. In that case the check would silently never fire.
        if anynan(model)
            @error "NaN params"
            break
        end

        # Calculate accuracy:
        acc = accuracy(test_set..., model)
        @info(@sprintf("[%d]: Test accuracy: %.4f", epoch_idx, acc))

        # If our accuracy is good enough, quit out.
        if acc >= 0.999
            @info " -> Early-exiting: We reached our target accuracy of 99.9%"
            break
        end

        # If this is the best accuracy we've seen so far, save the model out
        if acc >= best_acc
            @info " -> New best accuracy! Saving model out to mnist_conv.bson"
            BSON.@save joinpath(args.savepath, "mnist_conv.bson") params=cpu.(params(model)) epoch_idx acc
            best_acc = acc
            last_improvement = epoch_idx
        end

        # If we haven't seen improvement in 5 epochs, drop our learning rate:
        if epoch_idx - last_improvement >= 5 && opt.eta > 1e-6
            opt.eta /= 10.0
            @warn " -> Haven't improved in a while, dropping learning rate to $(opt.eta)!"
            # After dropping learning rate, give it a few epochs to improve
            last_improvement = epoch_idx
        end

        if epoch_idx - last_improvement >= 10
            @warn " -> We're calling this converged."
            break
        end
    end
end
"""
    test(; kws...)

Evaluate the saved model on the MNIST test set, printing and returning its
accuracy.

Rebuilds the architecture with `build_model`, loads the parameters that
`train` saved to `joinpath(args.savepath, "mnist_conv.bson")`, then moves the
model and data onto the GPU (if enabled) before scoring.
"""
function test(; kws...)
    args = Args(; kws...)
    # Load only the test split, as one giant minibatch. Going through
    # `get_processed_data` would also load and batch the entire training set
    # just to discard it.
    test_imgs, test_labels = MNIST.testdata(Float32)
    test_set = make_minibatch(test_imgs, test_labels, 1:length(test_labels))
    # Re-constructing the model with random initial weights
    model = build_model(args)
    # Loading the saved parameters. `@load` binds a local variable named
    # `params`, which shadows `Flux.params` for the rest of this function.
    BSON.@load joinpath(args.savepath, "mnist_conv.bson") params
    # Loading parameters onto the model (on the CPU, before moving to GPU)
    Flux.loadparams!(model, params)
    test_set = gpu.(test_set)
    model = gpu(model)
    return @show accuracy(test_set..., model)
end
# Script entry point: run from this file's directory so the default
# `savepath = "./"` resolves next to the script, then train and evaluate.
let script_dir = @__DIR__
    cd(script_dir)
    train()
    test()
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment