
Gist by @sdejean28 (last active January 23, 2021 16:50)
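using Plots
using Flux
using ColorSchemes
using NNlib          # provides gelu (Flux also re-exports it)

# Sum-of-squared-errors loss. The model is an explicit argument rather than a
# global `m`: in a script, the `m` assigned inside the `for` loop below is local
# to the loop, so a loss that reads a global `m` cannot see it.
loss(model, x, y) = sum((model(x) .- y) .^ 2)

# The seven data points. Float32 matches the element type of Flux's Dense weights.
xs = Float32[0.8, 2.0, 2.4, 0.0, 1.5, 3.0, 4.0]
ys = Float32[1.0, 3.0, 2.0, 0.5, 2.0, 2.5, 1.5]
# Flux.train! iterates over (input, target) pairs; each input is a 1-element vector.
dataset = [([x], [y]) for (x, y) in zip(xs, ys)]

scatter(xs, ys, legend = false)

N = 50
for j in 1:N
    # A freshly initialised 1 -> 1 -> 1 network with a single GELU hidden unit.
    m = Chain(Dense(1 => 1, gelu), Dense(1 => 1))
    # Plain gradient descent with learning rate 0.01.
    opt_state = Flux.setup(Descent(0.01), m)
    # Run j trains for 10*j epochs (one epoch = one pass over the 7 points).
    for epoch in 1:10*j
        Flux.train!(loss, m, dataset, opt_state)
    end
    f = x -> m([Float32(x)])[1]   # scalar view of the network, for plotting
    if j < N
        # Shorter runs: shades of blue, darker as the epoch count grows.
        plot!(f, 0, 5, linecolor = get(ColorSchemes.Blues_8, j / N))
    else
        # Longest run (10*N = 500 epochs): thick red curve.
        plot!(f, 0, 5, linecolor = get(ColorSchemes.Reds_6, 0.8), linewidth = 5)
    end
end
# Redraw the points on top of the curves and render the figure as PNG.
display(scatter!(xs, ys, legend = false, fmt = :png))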
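Each curve in the plot is the function computed by `Chain(Dense(1, 1, gelu), Dense(1, 1))` after its training run. Written out, with w1, b1, w2, b2 as names introduced here for the scalar weight and bias of the two Dense layers, and assuming that `gelu` is the tanh approximation and that `Descent` performs a plain per-example step for each (x, y) pair in dataset order, one epoch amounts to seven updates of the following form:

```latex
\begin{aligned}
h_i &= w_1 x_i + b_1, \qquad \hat y_i = w_2\,\operatorname{gelu}(h_i) + b_2, \qquad r_i = \hat y_i - y_i,\\
\ell_i &= r_i^2, \qquad L = \sum_{i=1}^{7} \ell_i,\\
\operatorname{gelu}(z) &= \tfrac12\, z\left(1 + \tanh\!\left(\sqrt{2/\pi}\,\bigl(z + 0.044715\,z^3\bigr)\right)\right),\\
\frac{\partial \ell_i}{\partial b_2} &= 2r_i, \qquad
\frac{\partial \ell_i}{\partial w_2} = 2r_i\,\operatorname{gelu}(h_i),\\
\frac{\partial \ell_i}{\partial b_1} &= 2r_i\,w_2\,\operatorname{gelu}'(h_i), \qquad
\frac{\partial \ell_i}{\partial w_1} = 2r_i\,w_2\,\operatorname{gelu}'(h_i)\,x_i,\\
\theta &\leftarrow \theta - \eta\,\nabla_\theta\, \ell_i, \qquad \eta = 0.01,\quad \theta = (w_1, b_1, w_2, b_2).
\end{aligned}
```

Run j applies this update 7 · 10j times, starting each run from a new random initialisation.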
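To see what the Flux training loop is doing without depending on Flux, the same fit can be sketched by hand. This is a minimal, dependency-free sketch under stated assumptions: GELU is taken to be the tanh approximation, updates are per-example in dataset order with learning rate 0.01, and the starting weights are fixed hypothetical values rather than Flux's random initialisation. The names `Net`, `predict`, `sse`, `epoch!` and `gelu_grad` are hypothetical, introduced only for this illustration.

```julia
# Hand-rolled 1 -> 1 -> 1 GELU network trained by per-example gradient descent.
# Assumptions: tanh-approximated GELU; fixed (hypothetical) initial weights.

const C = sqrt(2 / pi)   # √(2/π)
const A = 0.044715

gelu(z) = 0.5 * z * (1 + tanh(C * (z + A * z^3)))

# Derivative of the tanh-approximated GELU with respect to z.
function gelu_grad(z)
    t = tanh(C * (z + A * z^3))
    return 0.5 * (1 + t) + 0.5 * z * (1 - t^2) * C * (1 + 3A * z^2)
end

mutable struct Net
    w1::Float64   # hidden unit: h = w1*x + b1
    b1::Float64
    w2::Float64   # output: ŷ = w2*gelu(h) + b2
    b2::Float64
end

predict(p::Net, x) = p.w2 * gelu(p.w1 * x + p.b1) + p.b2

# Sum of squared errors over the whole dataset.
sse(p::Net, xs, ys) = sum((predict(p, x) - y)^2 for (x, y) in zip(xs, ys))

# One epoch: one gradient step per (x, y) pair, in dataset order.
function epoch!(p::Net, xs, ys, η)
    for (x, y) in zip(xs, ys)
        h = p.w1 * x + p.b1
        z = gelu(h)
        r = p.w2 * z + p.b2 - y          # residual ŷ - y
        dh = 2r * p.w2 * gelu_grad(h)    # ∂ℓ/∂h, computed with w2 before its update
        p.w2 -= η * 2r * z
        p.b2 -= η * 2r
        p.w1 -= η * dh * x
        p.b1 -= η * dh
    end
    return p
end

xs = [0.8, 2.0, 2.4, 0.0, 1.5, 3.0, 4.0]
ys = [1.0, 3.0, 2.0, 0.5, 2.0, 2.5, 1.5]

p = Net(0.5, 0.0, 0.5, 0.0)   # hypothetical starting point
initial_loss = sse(p, xs, ys)
for _ in 1:500                 # same epoch count as the longest (red) run
    epoch!(p, xs, ys, 0.01)
end
final_loss = sse(p, xs, ys)
println("SSE: ", initial_loss, " -> ", final_loss)
```

All gradients in `epoch!` are evaluated at the current parameters before any of them is changed, so the four updates form one simultaneous gradient step per data point.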