Skip to content

Instantly share code, notes, and snippets.

@whilo
Last active April 17, 2019 23:20
Show Gist options
  • Save whilo/a37e587b54457cae3ac80d4926932ee6 to your computer and use it in GitHub Desktop.
Save whilo/a37e587b54457cae3ac80d4926932ee6 to your computer and use it in GitHub Desktop.
CNF playground with nested Jacobian not working.
using DifferentialEquations
using Distributions
using Flux, DiffEqFlux
using Flux.Tracker
"""
    f(z, p)

Flow dynamics: elementwise `tanh(α * z + β)`, where `p = (α, β)`.
Works for scalar or array `z` via broadcasting.
"""
function f(z, p)
    α, β = p
    return tanh.(α .* z .+ β)
end
"""
    jacobian2(m, xp)

Workaround for Tracker's broken `jacobian`: compute the Jacobian of `m` at the
tracked array `xp` by backpropagating each output component separately and
reading the accumulated gradient off `xp`'s tracker.

Returns a `k × n` matrix where `k = length(m(xp))` and `n = length(xp)`.
"""
function jacobian2(m, xp)
    xdata = [el.data for el in xp]      # untracked values, used only for eltype
    y = m(xp)
    k = length(y)
    n = length(xdata)
    J = Matrix{eltype(xdata)}(undef, k, n)
    for i in 1:k
        back!(y[i], once = false)       # populate gradient accumulator
        J[i, :] = xp.tracker.grad
        # Reset the accumulator: `back!` ADDS into `grad`, so without this
        # every row after the first would contain a cumulative sum of rows.
        xp.tracker.grad .= 0
    end
    return J
end
"""
    cnf(du, u, p, t)

In-place ODE right-hand side for the 1-D continuous normalizing flow.
State `u = [z, logpz]`; parameters `p = (α, β)`.

- `du[1]`: the flow dynamics `tanh(α*z + β)` (the helper `f` inlined).
- `du[2]`: instantaneous log-density change `-tr(∂f/∂z) = -(1 - tanh(α*z + β)^2) * α`.
"""
function cnf(du, u, p, t)
    z = u[1]
    α, β = p
    h = tanh(α * z + β)
    du[1] = h
    du[2] = -(1 - h^2) * α # manual trace of the Jacobian
end
"""
    predict_rd(x)

Integrate the CNF ODE from initial state `[x, 0.0]` over `t ∈ [0, 10]` using
the global parameter vector `p`, returning the reverse-diff-tracked solution
sampled every 0.1 time units.
"""
function predict_rd(x)
    state0 = [x, 0.0]
    problem = ODEProblem(cnf, state0, (0.0, 10.0), p)
    return diffeq_rd(p, problem, Tsit5(), saveat = 0.1)
end
"""
    loss_rd(xs)

Negative mean log-likelihood of the samples `xs` under the flow: each sample
is pushed through the CNF, its log-density under a standard-normal prior is
corrected by the accumulated log-density change, and the mean is negated.
"""
function loss_rd(xs)
    prior = Normal(0.0, 1.0)
    finals = [predict_rd(x)[:, end] for x in xs]      # TODO better slicing
    zs = [state[1] for state in finals]
    deltas = [state[2] for state in finals]
    return -mean(logpdf.(prior, zs) - deltas)
end
# --- Training setup (top-level script; order matters: `p` must exist before
# `Flux.train!` calls `loss_rd` -> `predict_rd`, which read the global `p`). ---
opt = ADAM(0.1)
# One batch of 100 draws from N(2.0, 0.1); the flow should map these toward N(0, 1).
raw_data = [[rand(Normal(2.0, 0.1)) for i in 1:100]]
# Repeat the same batch for 100 training iterations.
data = Iterators.repeated(raw_data, 100);
p = param([0.0, 0.0]) # Initial Parameter Vector
# NOTE(review): `params` shadows `Flux.params` here — works, but easy to trip over.
params = Params([p])
Flux.train!(loss_rd, params, data, opt)
# check whether it looks standard normal
using Plots
# Final flow state for each training sample; `[:, end]` takes the last time point.
preds = [predict_rd(r)[:,end] for r in raw_data[1]];
# Histogram of the transformed samples z(T) (`.data` unwraps the tracked values).
histogram([p[1].data for p in preds])
# plot traces of flow
# Full solution trajectories z(t) for all 100 samples, plotted over time.
trajs = [predict_rd(raw_data[1][i]) for i in 1:100]
plot(trajs[1].t, [[u[1].data for u in traj.u] for traj in trajs])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment