Flux CIFAR-10
# Package setup: Flux for the model, MLDatasets for CIFAR-10, CUDA for GPU
# acceleration, BSON for checkpointing, ImageShow/Images for visualisation.
begin
    using Flux, MLDatasets, Statistics
    using Flux: onehotbatch, onecold, logitcrossentropy, params
    using Base.Iterators: partition
    using Printf, BSON
    using CUDA
    using ImageShow
    using Images
    # Allow (slow) scalar indexing into GPU arrays; convenient for inspection.
    CUDA.allowscalar(true)
end
# Inspect training sample 2: its class name (targets are 0-based, hence the +1)
# alongside the image itself.
MLDatasets.CIFAR10().metadata["class_names"][MLDatasets.CIFAR10().targets[2] + 1], convert2image(MLDatasets.CIFAR10(), 2)
# Hyperparameters; @kwdef lets each field be overridden by keyword.
Base.@kwdef mutable struct TrainArgs
    lr::Float64 = 3e-3          # initial learning rate
    epochs::Int = 100           # maximum number of epochs
    batch_size::Int = 32        # minibatch size
    savepath::String = "./"     # where to write model checkpoints
end
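# Usage sketch (illustrative values, not tuned): thanks to @kwdef, any subset
# of fields can be overridden by keyword while the rest keep their defaults.
@assert TrainArgs(lr = 1e-3, epochs = 10).batch_size == 32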
# Assemble one minibatch: slice the images at `idxs` into a dense Float32
# array and one-hot encode the matching labels (classes 0-9).
function make_minibatch(X, Y, idxs)
    X_batch = Array{Float32}(undef, size(X)[1:end-1]..., length(idxs))
    for i in 1:length(idxs)
        X_batch[:, :, :, i] = Float32.(X[:, :, :, idxs[i]])
    end
    Y_batch = onehotbatch(Y[idxs], 0:9)
    return (X_batch, Y_batch)
end
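# Shape sketch (assuming CIFAR-10's 32×32×3 layout): a 32-image batch comes
# back as a 32×32×3×32 Float32 array plus a 10×32 one-hot matrix.
let (X, Y) = MLDatasets.CIFAR10(:train)[:]
    Xb, Yb = make_minibatch(X, Y, 1:32)
    @assert size(Xb) == (32, 32, 3, 32) && size(Yb) == (10, 32)
end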
# Load CIFAR-10 and split the training set into minibatches; the whole test
# set is kept as one large batch for evaluation.
function get_processed_data(args)
    train_imgs, train_labels = MLDatasets.CIFAR10(:train)[:]
    mb_idxs = partition(1:length(train_labels), args.batch_size)
    train_set = [make_minibatch(train_imgs, train_labels, i) for i in mb_idxs]

    test_imgs, test_labels = MLDatasets.CIFAR10(:test)[:]
    test_set = make_minibatch(test_imgs, test_labels, 1:length(test_labels))
    return train_set, test_set
end
# Despite the names, these are plain in-memory arrays of batches, not lazy
# data loaders.
begin
    args = TrainArgs()
    trainloader, testloader = get_processed_data(args)
end
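# Quick count check (with the default batch_size = 32): 50 000 training images
# partition into 1 563 batches (the last one smaller), while the test set is a
# single 10 000-image batch.
@assert length(trainloader) == 1563 && size(testloader[1], 4) == 10_000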
# LeNet-style CNN. Shapes: 32×32×3 input → 28×28×16 (5×5 conv) → 14×14×16
# (pool) → 10×10×8 (5×5 conv) → 5×5×8 (pool) → flattened to 5·5·8 = 200
# features. The final layer emits raw logits, to pair with logitcrossentropy.
function build_model(args; imgsize = (32, 32, 3), nclasses = 10)
    return Chain(
        Conv((5, 5), imgsize[end] => 16, relu),
        MaxPool((2, 2)),
        Conv((5, 5), 16 => 8, relu),
        MaxPool((2, 2)),
        x -> reshape(x, :, size(x, 4)),
        Dense(200, 120),
        Dense(120, 84),
        Dense(84, nclasses))
end
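# Sanity check (illustrative; safe to delete): one dummy image should come out
# as a 10-element column of class logits, confirming the shape arithmetic above.
@assert size(build_model(TrainArgs())(rand(Float32, 32, 32, 3, 1))) == (10, 1)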
begin
    # Light data augmentation: additive Gaussian noise, generated on the GPU.
    augment(x) = x .+ gpu(0.1f0 * randn(eltype(x), size(x)))
    # True if any parameter array contains a NaN (i.e. training has diverged).
    anynan(x) = any(y -> any(isnan, y), x)
    # Fraction of images whose argmax class matches the label.
    accuracy(x, y, model) = mean(Flux.onecold(cpu(model(x)), 0:9) .== Flux.onecold(cpu(y), 0:9))
end
function train(; kws...)
    args = TrainArgs(; kws...)

    @info("Loading data set")
    train_set, test_set = get_processed_data(args)

    @info("Building model...")
    model = build_model(args)

    # Load model and datasets onto GPU
    train_set = gpu.(train_set)
    test_set = gpu.(test_set)
    model = gpu(model)

    # Pre-compile the model with one forward pass
    model(train_set[1][1])
    # Loss: add augmentation noise to the input, then compare the model's
    # logits for the augmented batch against the one-hot labels.
    function loss(x, y)
        x̂ = augment(x)
        ŷ = model(x̂)
        return logitcrossentropy(ŷ, y)
    end
    opt = Adam(args.lr)

    @info("Beginning training loop...")
    best_acc = 0.0
    last_improvement = 0
    for epoch_idx in 1:args.epochs
        # One pass over all minibatches
        Flux.train!(loss, params(model), train_set, opt)

        # Abort if training has diverged
        if anynan(params(model))
            @error "NaN params"
            break
        end

        @info("Computing accuracy...")
        acc = accuracy(test_set..., model)
        @info(@sprintf("[%d]: Test accuracy: %.4f", epoch_idx, acc))

        if acc >= 0.999
            @info(" -> Early-exiting: We reached our target accuracy of 99.9%")
            break
        end

        # If this is the best accuracy we've seen so far, save the model out
        if acc >= best_acc
            @info(" -> New best accuracy! Saving model out to cifar_10.bson")
            BSON.@save joinpath(args.savepath, "cifar_10.bson") params=cpu.(params(model)) epoch_idx acc
            best_acc = acc
            last_improvement = epoch_idx
        end

        # Drop the learning rate if there has been no improvement for 5 epochs
        if epoch_idx - last_improvement >= 5 && opt.eta > 1e-6
            opt.eta /= 10.0
            @warn(" -> Haven't improved in a while, dropping learning rate to $(opt.eta)!")
            last_improvement = epoch_idx
        end

        # Give up entirely after 10 epochs without improvement
        if epoch_idx - last_improvement >= 10
            @warn(" -> No improvement in 10 epochs; stopping early.")
            break
        end
    end
end
train()
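# Inference sketch (assumptions: training has produced "cifar_10.bson" at the
# default savepath, and the checkpoint's implicit-params list is compatible
# with Flux.loadparams!). `predict_test_image` is a hypothetical helper, not
# part of the original script.
function predict_test_image(idx::Integer; path = "./cifar_10.bson")
    weights = BSON.load(path)[:params]   # CPU copies of the saved parameters
    model = build_model(TrainArgs())
    Flux.loadparams!(model, weights)     # restore weights into a fresh model
    X, _ = MLDatasets.CIFAR10(:test)[:]
    x = Float32.(X[:, :, :, idx:idx])    # keep the batch dimension
    class = onecold(model(x), 0:9)[1]    # logits -> predicted label (0-9)
    return MLDatasets.CIFAR10().metadata["class_names"][class + 1]
end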