Flux Cifar-10 (@iskyd, March 16, 2023)
begin
    using Flux, MLDatasets, Statistics
    using Flux: onehotbatch, onecold, logitcrossentropy, params
    using Base.Iterators: partition
    using Printf, BSON
    using CUDA
    using ImageShow
    using Images
    # Permit scalar GPU indexing. Convenient, but slow if it happens in a hot loop.
    CUDA.allowscalar(true)
end
# Preview sample 2 of CIFAR-10: its class name and the image itself.
MLDatasets.CIFAR10().metadata["class_names"][MLDatasets.CIFAR10().targets[2] + 1], convert2image(MLDatasets.CIFAR10(), 2)
Base.@kwdef mutable struct TrainArgs
    lr::Float64 = 3e-3
    epochs::Int = 100
    batch_size::Int = 32
    savepath::String = "./"
end
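# Base.@kwdef generates a keyword constructor, so any hyperparameter can be
# overridden at the call site. A small illustrative sketch (values are arbitrary):
let args = TrainArgs(lr = 1e-3, epochs = 50)
    @assert args.lr == 1e-3 && args.batch_size == 32  # unspecified fields keep their defaults
end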
function make_minibatch(X, Y, idxs)
    # Copy the selected images into a Float32 WHCN batch and one-hot the labels.
    X_batch = Array{Float32}(undef, size(X)[1:end-1]..., length(idxs))
    for i in 1:length(idxs)
        X_batch[:, :, :, i] = Float32.(X[:, :, :, idxs[i]])
    end
    Y_batch = onehotbatch(Y[idxs], 0:9)
    return (X_batch, Y_batch)
end
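# Shape check for make_minibatch, using dummy data rather than the real CIFAR-10 arrays:
let X = rand(Float32, 32, 32, 3, 100), Y = rand(0:9, 100)
    Xb, Yb = make_minibatch(X, Y, 1:32)
    @assert size(Xb) == (32, 32, 3, 32)  # WHCN layout, as expected by Flux's Conv
    @assert size(Yb) == (10, 32)         # one-hot labels, one column per image
end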
function get_processed_data(args)
    train_imgs, train_labels = MLDatasets.CIFAR10(:train)[:]
    # Split the training indices into batch-sized chunks and materialise each minibatch.
    mb_idxs = partition(1:length(train_labels), args.batch_size)
    train_set = [make_minibatch(train_imgs, train_labels, i) for i in mb_idxs]
    # The test set is kept as one big batch.
    test_imgs, test_labels = MLDatasets.CIFAR10(:test)[:]
    test_set = make_minibatch(test_imgs, test_labels, 1:length(test_labels))
    return train_set, test_set
end
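# Materialising every minibatch up front works, but Flux's built-in DataLoader
# streams batches on the fly instead. A self-contained sketch with dummy data,
# not what the rest of this gist uses:
let X = rand(Float32, 32, 32, 3, 64), Y = onehotbatch(rand(0:9, 64), 0:9)
    loader = Flux.DataLoader((X, Y); batchsize = 16, shuffle = true)
    @assert length(loader) == 4  # 64 samples / 16 per batch
end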
begin
    args = TrainArgs()
    # These are plain arrays of (X, Y) minibatches, not DataLoaders;
    # train() below rebuilds its own copies.
    train_set, test_set = get_processed_data(args)
end
function build_model(args; imgsize = (32, 32, 3), nclasses = 10)
    model = Chain(
        Conv((5, 5), imgsize[end] => 16, relu),
        MaxPool((2, 2)),
        Conv((5, 5), 16 => 8, relu),
        MaxPool((2, 2)),
        # Flatten the 5×5×8 feature maps into a 200-element vector per image.
        x -> reshape(x, :, size(x, 4)),
        Dense(200, 120),
        Dense(120, 84),
        Dense(84, nclasses))
    model
end
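# Where does Dense(200, 120) come from? Tracing the shapes: 32×32×3 → Conv 5×5 →
# 28×28×16 → MaxPool → 14×14×16 → Conv 5×5 → 10×10×8 → MaxPool → 5×5×8 = 200
# features per image. Flux.outputsize can verify this without running real data:
let m = build_model(TrainArgs())
    @assert Flux.outputsize(m, (32, 32, 3, 1)) == (10, 1)
end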
begin
    # Simple data augmentation: add Gaussian noise to the input batch.
    augment(x) = x .+ gpu(0.1f0 * randn(eltype(x), size(x)))
    # True if any parameter array contains a NaN.
    anynan(x) = any(y -> any(isnan, y), x)
    accuracy(x, y, model) = mean(onecold(cpu(model(x)), 0:9) .== onecold(cpu(y), 0:9))
end
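# A quick sanity check of the label encoding round trip that accuracy relies on:
let y = onehotbatch([0, 1, 9], 0:9)       # 10×3 one-hot matrix
    @assert onecold(y, 0:9) == [0, 1, 9]  # onecold inverts onehotbatch
end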
function train(; kws...)
    args = TrainArgs(; kws...)

    @info("Loading data set")
    train_set, test_set = get_processed_data(args)

    @info("Building model...")
    model = build_model(args)

    # Load model and datasets onto the GPU
    train_set = gpu.(train_set)
    test_set = gpu.(test_set)
    model = gpu(model)

    # Pre-compile the model with one forward pass
    model(train_set[1][1])

    # Define the loss on an augmented copy of the input
    function loss(x, y)
        x̂ = augment(x)
        ŷ = model(x̂)  # feed the augmented input, not the original x
        return logitcrossentropy(ŷ, y)
    end

    opt = Adam(args.lr)

    @info("Beginning training loop...")
    best_acc = 0.0
    last_improvement = 0
    for epoch_idx in 1:args.epochs
        Flux.train!(loss, params(model), train_set, opt)

        if anynan(params(model))
            @error "NaN params"
            break
        end

        @info("Computing accuracy...")
        acc = accuracy(test_set..., model)
        @info(@sprintf("[%d]: Test accuracy: %.4f", epoch_idx, acc))

        if acc >= 0.999
            @info(" -> Early-exiting: We reached our target accuracy of 99.9%")
            break
        end

        # If this is the best accuracy we've seen so far, save the model out
        if acc >= best_acc
            @info(" -> New best accuracy! Saving model out to cifar_10.bson")
            BSON.@save joinpath(args.savepath, "cifar_10.bson") params=cpu.(params(model)) epoch_idx acc
            best_acc = acc
            last_improvement = epoch_idx
        end

        # Drop the learning rate after 5 epochs without improvement
        if epoch_idx - last_improvement >= 5 && opt.eta > 1e-6
            opt.eta /= 10.0
            @warn(" -> Haven't improved in a while, dropping learning rate to $(opt.eta)!")
            last_improvement = epoch_idx
        end

        # Give up after 10 epochs without improvement
        if epoch_idx - last_improvement >= 10
            @warn(" -> Model converged.")
            break
        end
    end
end
train()
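# To reuse the checkpoint saved during training, rebuild the architecture and
# load the stored parameter arrays back in. A sketch, assuming cifar_10.bson
# was written to the current directory; the checkpoint is read into a fresh
# name to avoid shadowing Flux's params function:
if isfile("cifar_10.bson")
    ckpt = BSON.load("cifar_10.bson")
    trained = build_model(TrainArgs())
    Flux.loadparams!(trained, ckpt[:params])
    @info "Restored checkpoint from epoch $(ckpt[:epoch_idx]) (acc = $(ckpt[:acc]))"
end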