@caseykneale
Created May 15, 2020 23:57
Old school example, new school API
using DiffEqFlux, OrdinaryDiffEq, Flux, MLDataUtils, NNlib
using Flux: logitcrossentropy
using MLDatasets: MNIST
function loadmnist(batchsize = bs)
    # Use MLDataUtils LabelEnc for natural onehot conversion
    onehot(labels_raw) = convertlabel(LabelEnc.OneOfK, labels_raw,
                                      LabelEnc.NativeLabels(collect(0:9)))
    # Load MNIST
    imgs, labels_raw = MNIST.traindata()
    # Process images into (H, W, C, BS) batches
    x_train = reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)) |> gpu
    x_train = batchview(x_train, batchsize)
    # Onehot and batch the labels
    y_train = onehot(labels_raw) |> gpu
    y_train = batchview(y_train, batchsize)
    return x_train, y_train
end
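# Shape note (a sketch, assuming the standard 60 000-image MNIST training
# set): loadmnist returns batchviews, so each element of x_train is a
# 28×28×1×batchsize array and each element of y_train is a 10×batchsize
# one-hot matrix.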
# Main
const bs = 128
x_train, y_train = loadmnist(bs)
# down flattens each 28×28×1 image to a 784-vector and projects it to a
# 20-dimensional latent state for the ODE layer.
down = Chain(x -> reshape(x, (28 * 28, :)),
             Dense(784, 20, tanh))
nfe = 0  # counter for ODE function evaluations (only reset in the callback here)
# nn defines the ODE dynamics: a 20 -> 20 map through two hidden layers.
nn = Chain(#(x,p) -> x,
           Dense(20, 10, tanh),
           Dense(10, 10, tanh),
           Dense(10, 20, tanh))
# fc maps the 20-dimensional terminal state to 10 class logits.
fc = Chain(Dense(20, 10))
nn_ode = NeuralODE(nn, (0.f0, 1.f0), Tsit5(),
                   save_everystep = false,
                   reltol = 1e-3, abstol = 1e-3,
                   save_start = false)
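# With save_everystep = false and save_start = false the solver keeps only
# the state at t = 1, so nn_ode acts as a continuous-depth block: it
# integrates dx/dt = nn(x) from t = 0 to t = 1 and returns the final state.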
function DiffEqArray_to_Array(x)
    xarr = Array(x)
    return reshape(xarr, size(xarr)[1:2])
end
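# The NeuralODE layer returns a DiffEqArray holding the saved timepoints.
# Only one timepoint is saved here, so for a batch of 128 the raw solution
# has size (20, 128, 1); the reshape drops the singleton time axis to give
# a plain (20, 128) matrix the downstream Dense layer can consume.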
m = Chain(down,
          nn_ode,
          DiffEqArray_to_Array,
          fc)
m_no_ode = Chain(down, nn, fc)
# Show that the pieces compose: run one batch through the encoder, through
# the ODE layer, and through both full models.
x_d = down(x_train[1])
nn_ode(x_d)
x_m = m(x_train[1])
x_m = m_no_ode(x_train[1])
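# Both models should return a 10 × batchsize matrix of logits
# (10 × 128 with the settings above).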
classify(x) = argmax.(eachcol(x))
function accuracy(model, data; n_batches = 100)
    total_correct = 0
    total = 0
    for (x, y) in collect(data)[1:n_batches]
        target_class = classify(cpu(y))
        predicted_class = classify(cpu(model(x)))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
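# Note: y and model(x) are moved to the CPU before classify, since taking
# argmax column by column on a GPU array would fall back to slow scalar
# indexing.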
# Burn-in accuracy
accuracy(m, zip(x_train, y_train))
loss(x, y) = logitcrossentropy(m(x), y)
# Burn-in loss
loss(x_train[1], y_train[1])
opt = ADAM(0.05)
iter = 0
cb() = begin
    global iter += 1
    @show iter
    @show loss(x_train[1], y_train[1])
    #@show sum(down[2].W)   # Updates
    #@show cpu(fc)[1].W[1]  # Updates
    #@show nn_ode.p[1]      # Updates
    (iter % 10 == 0) && @show accuracy(m, zip(x_train, y_train))
    global nfe = 0
end
# "New school" API, left commented for comparison:
# res1 = DiffEqFlux.sciml_train(loss, params(down, nn_ode.p, fc),
#                               opt, zip(x_train, y_train),
#                               cb = cb, maxiters = 10000)
# cb(res1.minimizer, loss(res1.minimizer)...; doplot = true)
# "Old school" API:
Flux.train!(loss, params(down, nn_ode.p, fc),
            zip(x_train, y_train), opt, cb = cb)
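# The callback above prints on every batch. A lighter-weight variant (a
# sketch, not in the original) would rate-limit it with Flux's built-in
# throttle, e.g.:
# Flux.train!(loss, params(down, nn_ode.p, fc),
#             zip(x_train, y_train), opt, cb = Flux.throttle(cb, 10))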