Skip to content

Instantly share code, notes, and snippets.

@baggepinnen
Created November 2, 2017 13:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save baggepinnen/5da18d7abecec112f532269d24a127d0 to your computer and use it in GitHub Desktop.
Save baggepinnen/5da18d7abecec112f532269d24a127d0 to your computer and use it in GitHub Desktop.
This script causes a segfault; sometimes it has to be run twice to trigger it. Julia v0.6.1
using Flux
using Flux: back!, truncate!, treelike, train!, mse
# Problem dimensions read by `generate_data` and the model definition.
# Declared `const` so functions that read them do not go through untyped globals.
const N = 200  # number of simulated time steps
const n = 2    # state dimension
"""
    generate_data()

Simulate a noisy linear system and return one-step-ahead training pairs.

Starting from a standard-normal initial state, the state evolves as
`x[:, i+1] = A * x[:, i] + 0.01 * randn(n)`, using the global dimensions `n`
(state size) and `N` (number of simulated steps).

# Returns
The tuple `(x, y, A)`, where `x` holds states `1:N-1`, `y` holds states `2:N`
(each `n × (N-1)`), and `A` is the transition matrix.

`A` is a fixed 2×2 matrix, so the global `n` must equal 2.
"""
function generate_data()
    # Upper-triangular transition matrix; its diagonal entries are its eigenvalues.
    A = [0.999 1; 0 0.8]

    states = zeros(n, N)
    states[:, 1] = randn(n)
    for i in 1:N-1
        states[:, i+1] = A * states[:, i] + 0.01randn(n)
    end

    # Inputs are every state but the last; targets are the same states shifted by one.
    return states[:, 1:end-1], states[:, 2:end], A
end
"""
    fit_model(m, loss, x, y, cb = () -> (); iterations = 2000, learning_rate = 0.01)

Train `m` with ADAM on the single batch `(x, y)`, presented `iterations` times,
and return the loss on `(x, y)` after training.

# Arguments
- `m`: model whose parameters (`params(m)`) are handed to the optimiser.
- `loss`: callable invoked as `loss(x, y)`.
- `x`, `y`: the batch passed to `loss` on every training step.
- `cb`: callback forwarded to `train!` (defaults to a no-op).

# Keywords
- `iterations`: number of times the batch is repeated in the dataset.
- `learning_rate`: second argument passed to `ADAM`.
"""
function fit_model(m, loss, x, y, cb = () -> (); iterations = 2000, learning_rate = 0.01)
    dataset = Iterators.repeated((x, y), iterations)
    opt = ADAM(params(m), learning_rate)
    train!(loss, dataset, opt, cb = cb)
    # NOTE(review): `.data[1]` presumably unwraps Flux's tracked loss value — confirm.
    return loss(x, y).data[1]
end
# Driver script. According to the gist description this is the part that ends in a
# segfault on Julia v0.6.1, so the statements below are deliberately left untouched.

# Simulated inputs `x`, one-step-ahead targets `y`, and the transition matrix `A`.
x,y,A = generate_data()
# x = param(x)
# Width of the hidden layers.
np = 17
# Four Dense layers, n -> np -> np -> np -> n; the first three are given `swish`,
# the last is given no activation argument.
m = Chain(
Dense(n,np,swish),
Dense(np,np,swish),
Dense(np,np,swish),
Dense(np,n)
)
# Mean squared error between the model output for `x` and the target `y`.
loss = (x,y) -> Flux.mse(m(x), y)
# dataset = Iterators.repeated((x, y), 2000)
using IterTools
# Replace the inputs with zeros and the targets with their element-wise sign.
x = 0*x; y = sign.(y)
batch_size = 3
# NOTE(review): `epochs` is only referenced by the commented-out line below.
epochs = 5
# dataset = repeated(zip(partition(x,batch_size),partition(y,batch_size)), epochs)
# NOTE(review): `partition` presumably comes from IterTools and groups the array's
# elements `batch_size` at a time rather than slicing whole columns — confirm. If so,
# the model receives tuples of scalars instead of n-row batches, which may be related
# to the reported crash.
dataset = zip(collect(partition(x,batch_size)),collect(partition(y,batch_size)))
# Print every (input, target) pair before training.
for xi in dataset
@show xi
end
# Same optimiser settings as in `fit_model`: ADAM over the model parameters, 0.01.
opt = ADAM(params(m), 0.01)
train!(loss, dataset, opt)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment