# Training callback: evaluate the loss on the full training set and print it with
# `@show`. In this script it is handed to `Flux.train!` wrapped in
# `Flux.throttle(callback, 1)`, so it reports at most once per second during training
# rather than once per epoch.
callback = function ()
    @show loss(xtrn, ytrn)
end;

# Train the model: define an accuracy metric, then run the training loop
"""
    accuracy(x, y)

Return the fraction of samples whose class predicted by the global `model` (via
`Flux.onecold`) equals the class encoded in `y`.

`x` is transposed before being passed to `model`, and the number of samples is taken as
`size(y, 2)`. Presumably `x` holds one sample per row and `y` is a one-hot label matrix
with one sample per column — TODO confirm against the data-loading code.
"""
function accuracy(x, y)
    predicted = Flux.onecold(model(x'))
    actual = Flux.onecold(y)
    ncorrect = count(predicted .== actual)
    return ncorrect / size(y, 2)
end

# Train for `nepochs` epochs, recording loss and accuracy on the training and test data
# before training and after every epoch. Row k of `err` / `acc` is the snapshot taken
# after k - 1 epochs; column 1 is the training split, column 2 the test split.
err, acc = let
    nepochs = 100
    # Build the optimiser once: constructing a fresh `ADAM()` every epoch would throw
    # away its accumulated moment estimates and restart the adaptive step sizes.
    opt = Flux.ADAM()
    ps = Flux.params(model)

    # 1×2 row holding metric `f` evaluated on (training, test) data.
    splitrow(f) = hcat(f(xtrn, ytrn), f(xtst, ytst))
    # Loss with the Tracker wrapper stripped, so the history stores plain numbers.
    lossvalue(x, y) = Flux.Tracker.data(loss(x, y))

    # Collect one row per snapshot and concatenate once at the end, instead of
    # re-allocating the whole history with `vcat` on every epoch.
    errrows = [splitrow(lossvalue)]
    accrows = [splitrow(accuracy)]
    for _ in 1:nepochs
        Flux.train!(loss, ps, minibatches, opt, cb = Flux.throttle(callback, 1))
        push!(errrows, splitrow(lossvalue))
        push!(accrows, splitrow(accuracy))
    end
    reduce(vcat, errrows), reduce(vcat, accrows)
end

# Save the loss and accuracy histories to CSV files for visualization.
# `Base.write` has no method for a `DataFrame`, so each table is rendered through the
# "text/csv" MIME type that DataFrames implements and that text is written to the file.
for (path, history) in ("error-flux.csv" => err, "accuracy-flux.csv" => acc)
    open(path, "w") do io
        show(io, MIME"text/csv"(), DataFrame(history, [:training, :testing]))
    end
end