@AStupidBear
Created May 8, 2020 06:19
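```julia
using Flux, Random, Statistics
Random.seed!(1234)
# Input: 10 features × batch of 1 × 100 timesteps.
x = randn(Float32, 10, 1, 100)
# Target: mean over the 10 features at each timestep (size 1 × 1 × 100).
y = mean(x, dims = 1)
# Two stacked LSTM layers: 10 inputs → 100 hidden units → 1 output.
model = Chain(LSTM(10, 100), LSTM(100, 1))
```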
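The script lays its data out as features × batch × time, and the target is the per-timestep mean of the features. The stdlib-only sketch below reproduces that layout and splits the array into one matrix per timestep, which is the job `Flux.unstack(x, 3)` does inside `loss`. The helpers `make_toy_data` and `timesteps` are hypothetical names, the slicing is plain Base indexing rather than Flux's own implementation, and it uses a local `MersenneTwister` RNG, so the numbers will not match the gist's.

```julia
using Random, Statistics

# Hypothetical helper: same layout as the gist, (features, batch, time).
function make_toy_data(; nfeat = 10, batch = 1, nsteps = 100, seed = 1234)
    rng = MersenneTwister(seed)                  # local RNG, not the global seed
    x = randn(rng, Float32, nfeat, batch, nsteps)
    y = mean(x, dims = 1)                        # mean over features: (1, batch, nsteps)
    return x, y
end

# Hypothetical helper: one matrix per timestep, i.e. slices along dimension 3.
timesteps(a::AbstractArray{<:Any,3}) = [a[:, :, t] for t in axes(a, 3)]

x_demo, y_demo = make_toy_data()
xs_demo = timesteps(x_demo)
println(size(x_demo), " ", size(y_demo), " ", length(xs_demo), " ", size(xs_demo[1]))
# prints (10, 1, 100) (1, 1, 100) 100 (10, 1)
```

Each of the 100 slices is a 10 × 1 matrix, which is the shape the first `LSTM(10, 100)` layer receives at every timestep. The `_demo` names keep the sketch from overwriting the script's `x` and `y` if both run in one session.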
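```julia
function loss(x, y)
    # Split into per-timestep 10×1 inputs and 1×1 targets.
    xs = Flux.unstack(x, 3)
    ys = Flux.unstack(y, 3)
    # Feed the timesteps in order; each call advances the LSTMs' hidden state.
    ŷs = model.(xs)
    l = 0f0
    for t in 1:length(ŷs)
        l += Flux.mse(ŷs[t], ys[t])  # Flux.mse takes (prediction, target)
    end
    return l / length(ŷs)  # average the per-timestep MSE over the sequence
end
```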
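Because every timestep contributes a target of the same size (1 × 1 here), averaging the per-timestep MSE over the sequence gives the same number as a single MSE over the whole 1 × 1 × 100 array. The sketch below checks that identity with stand-in predictions instead of the LSTM. `sqerr_mean` and `sequence_loss` are hypothetical names; `sqerr_mean` restates the mean-squared-error formula rather than calling `Flux.mse`.

```julia
using Statistics

# Same formula as Flux.mse(ŷ, y): mean of squared differences.
sqerr_mean(ŷ, y) = mean(abs2.(ŷ .- y))

# Loss in the style of the gist: sum the per-timestep MSE, divide by the step count.
function sequence_loss(ŷs, ys)
    l = 0f0
    for t in eachindex(ŷs, ys)
        l += sqerr_mean(ŷs[t], ys[t])
    end
    return l / length(ŷs)
end

# Stand-in data: targets 1..4 at four timesteps, predictions all zero.
ys_demo = [fill(Float32(t), 1, 1) for t in 1:4]
ŷs_demo = [zeros(Float32, 1, 1) for _ in 1:4]
per_step = sequence_loss(ŷs_demo, ys_demo)
whole = sqerr_mean(cat(ŷs_demo...; dims = 3), cat(ys_demo...; dims = 3))
println(per_step, " ", whole)  # prints 7.5 7.5
```

With targets 1, 2, 3, 4 and zero predictions the per-step errors are 1, 4, 9, 16, so both forms give 30 / 4 = 7.5. If timesteps had targets of different sizes, the two forms would weight them differently and no longer agree.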
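```julia
# ADAM with learning rate 1e-3 and decay rates (β1, β2) = (0.9, 0.999).
opt = ADAM(1e-3, (0.9, 0.999))
# After each update: print the loss (Flux.data unwraps the tracked value), then reset the LSTM state.
cb = () -> (println(Flux.data(loss(x, y))); Flux.reset!(model))
# 10 epochs, each making 100 passes over the same (x, y) sequence.
Flux.@epochs 10 Flux.train!(loss, params(model), repeat([(x, y)], 100), opt, cb = cb)
```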
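In the Flux versions this gist targets, recurrent layers such as `LSTM` keep their hidden state between calls, so one pass over the sequence (including the extra `loss(x, y)` call inside the callback) leaves state behind for the next pass; `Flux.reset!(model)` clears it. The toy cell below is a hypothetical stand-in, not Flux's `LSTM`: a leaky accumulator `h ← a*h + x`, with `reset_state!` playing the role of `Flux.reset!`. It shows that identical inputs produce different outputs until the state is reset.

```julia
# Hypothetical stand-in for a stateful recurrent layer (NOT Flux's LSTM):
# a leaky accumulator h ← a*h + x whose state persists between calls.
mutable struct ToyRecurrent
    a::Float32   # decay applied to the previous state
    h0::Float32  # initial state
    h::Float32   # current state, updated on every call
end
ToyRecurrent(a) = ToyRecurrent(a, 0f0, 0f0)

# Calling the cell advances its state, like one timestep of a recurrent layer.
(m::ToyRecurrent)(x::Real) = (m.h = m.a * m.h + Float32(x))

# Analogue of Flux.reset!: restore the initial state.
reset_state!(m::ToyRecurrent) = (m.h = m.h0; m)

cell = ToyRecurrent(0.5f0)
first_pass  = [cell(v) for v in (1, 1, 1)]   # starts from h = 0
second_pass = [cell(v) for v in (1, 1, 1)]   # continues from h = 1.75
reset_state!(cell)
after_reset = [cell(v) for v in (1, 1, 1)]   # matches first_pass again
println(first_pass, " ", second_pass, " ", after_reset)
```

With `a = 0.5`, three inputs of 1 give 1, 1.5, 1.75 from a fresh state; repeating them without a reset starts from 1.75 and yields 1.875, 1.9375, 1.96875 instead.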