# Print the current training loss. The callback is wrapped in
# `Flux.throttle(callback, 1)` below, so it fires at most once per second
# while `Flux.train!` runs -- not once per epoch.
callback = () -> @show(loss(xtrn, ytrn))

"""
    accuracy(x, y)

Return the fraction of samples whose predicted class (`Flux.onecold` of
`model(x')`) equals the label class (`Flux.onecold(y)`), using `size(y, 2)`
as the sample count.

NOTE(review): assumes `x` stores one sample per row (hence the transpose)
and `y` is one-hot encoded with one sample per column -- confirm against
the data-preparation code.
"""
function accuracy(x, y)
    return sum(Flux.onecold(model(x')) .== Flux.onecold(y)) / size(y, 2)
end

"""
    write_metrics_csv(path, metrics; header = ("training", "testing"))

Write the rows of the matrix `metrics` to `path` as comma-separated values,
preceded by a header line built from `header`. Returns `path`.
"""
function write_metrics_csv(path::AbstractString, metrics::AbstractMatrix;
                           header = ("training", "testing"))
    open(path, "w") do io
        println(io, join(header, ','))
        for i in axes(metrics, 1)
            println(io, join(view(metrics, i, :), ','))
        end
    end
    return path
end

# Row 1 holds the metrics of the untrained model; each epoch appends one row.
# `Flux.Tracker.data` unwraps the tracked loss so only plain numbers are stored.
err = hcat(Flux.Tracker.data(loss(xtrn, ytrn)), Flux.Tracker.data(loss(xtst, ytst)))
acc = hcat(accuracy(xtrn, ytrn), accuracy(xtst, ytst))

# Create the optimiser once: ADAM carries per-parameter moment estimates
# between updates, and building a fresh `Flux.ADAM()` every epoch would
# silently reset them at the start of each epoch.
optimizer = Flux.ADAM()

# Train the model
for epoch in 1:100
    Flux.train!(loss, Flux.params(model), minibatches, optimizer,
                cb = Flux.throttle(callback, 1))
    # `global` is required to rebind the top-level accumulators from the loop body.
    global err = vcat(err, hcat(Flux.Tracker.data(loss(xtrn, ytrn)),
                                Flux.Tracker.data(loss(xtst, ytst))))
    global acc = vcat(acc, hcat(accuracy(xtrn, ytrn), accuracy(xtst, ytst)))
end

# Save the loss and accuracy histories to CSV for visualization.
write_metrics_csv("error-flux.csv", err)
write_metrics_csv("accuracy-flux.csv", acc)