@alstat
Created March 9, 2018 16:41
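# Multilayer-perceptron classifier for CIFAR-100 (fine labels), written against
# the 2018-era Flux API (implicit parameters, `SGD(params(model))`).
using Images
using MLDatasets
using Flux: ADAM,
            SGD,
            argmax,
            Chain,
            crossentropy,
            Dense,
            onehotbatch,
            params,
            relu,
            softmax,
            throttle,
            train!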
const N = 2000      # number of CIFAR-100 training images to use
const EPOCHS = 30   # number of passes over the single full batch
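# CIFAR-100 training split: 32×32×3 images plus coarse (20-class) and fine (100-class) labels
x_train, y_train_coarse, y_train_fine = CIFAR100.traindata();
# Flatten each of the first N images into a 32*32*3 = 3072-element float vector
x_vec = [float.(reshape(x_train[:, :, :, i], :)) for i in 1:N];
# Stack the vectors column-wise, so X is 3072×N (one image per column)
X = hcat(x_vec...);
# One-hot encode the fine labels (values 0:99), so Y is 100×N
Y = onehotbatch(y_train_fine[1:N], 0:99);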
# Two-layer perceptron: 3072 inputs -> 320 hidden units (ReLU) -> 100 class probabilities
model = Chain(
    Dense(32^2 * 3, 32 * 10, relu),
    Dense(32 * 10, 100),
    softmax
);
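# Cross-entropy between the predicted class probabilities and the one-hot targets
loss(x, y) = crossentropy(model(x), y);
# Fraction of images whose most probable class equals the true class
# (the Flux `argmax` imported above returns one index per column)
accuracy(x, y) = mean(argmax(model(x)) .== argmax(y));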
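# Repeat the full (X, Y) batch EPOCHS times: training takes EPOCHS gradient steps on all N images
dataset = Base.Iterators.repeated((X, Y), EPOCHS);
# Print the training loss; throttled below to at most once every 10 seconds
evalcb = () -> @show(loss(X, Y));
# Plain gradient descent over the model's parameters (default learning rate)
opt = SGD(params(model));
train!(loss, dataset, opt, cb = throttle(evalcb, 10));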
# Accuracy on the same N training images (no held-out test set is evaluated)
accuracy(X, Y)
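The script moves data through several shapes: 32×32×3 images become 3072-element columns, labels in 0:99 become 100-row one-hot columns, and accuracy compares one predicted index per column. The sketch below reproduces that bookkeeping in plain Base Julia, without Flux or MLDatasets, so the shapes can be checked in isolation. The helper names (`flatten_images`, `onehot_columns`, `colargmax`, `column_accuracy`, `dense_params`) are hypothetical and are not part of either library. Flux's own `onehotbatch` returns its own one-hot matrix type rather than the `BitMatrix` used here.

```julia
# Base-Julia sketch of the tensor bookkeeping in the script above.
# All helper names are hypothetical, for illustration only (not Flux/MLDatasets API).

# Flatten a W×H×C×N image array into a (W*H*C)×N matrix, one image per column.
# reshape is column-major, so column i equals reshape(x[:, :, :, i], :).
flatten_images(x::AbstractArray{T,4}) where {T} = reshape(x, :, size(x, 4))

# One-hot encode integer labels: column j has a single `true` in the row
# corresponding to labels[j] within `classes` (e.g. 0:99 for CIFAR-100 fine labels).
function onehot_columns(labels::AbstractVector{<:Integer}, classes::AbstractUnitRange{<:Integer})
    Y = falses(length(classes), length(labels))
    for (j, l) in enumerate(labels)
        Y[l - first(classes) + 1, j] = true
    end
    return Y
end

# Row index of the largest entry in each column. Note that Base.argmax applied
# to a whole matrix returns a single CartesianIndex, not one index per column.
colargmax(Y::AbstractMatrix) = [argmax(view(Y, :, j)) for j in axes(Y, 2)]

# Fraction of columns whose predicted class matches the one-hot target.
column_accuracy(Ŷ::AbstractMatrix, Y::AbstractMatrix) =
    count(colargmax(Ŷ) .== colargmax(Y)) / size(Y, 2)

# Trainable parameters of a fully connected layer: weight matrix plus bias vector.
dense_params(n_in::Integer, n_out::Integer) = n_in * n_out + n_out

# Parameter count of the two Dense layers in the script's model.
println(dense_params(32^2 * 3, 32 * 10) + dense_params(32 * 10, 100))  # prints 1015460
```

By this count the model has 3072·320 + 320 = 983,360 plus 320·100 + 100 = 32,100, or 1,015,460 parameters in total, fitted to only 2,000 images. The training-set accuracy printed by the script's last line may therefore overstate how well the model generalizes.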