@ChrisRackauckas
Created November 25, 2019 00:29
Animation of neural ordinary differential equations with DiffEqFlux.jl
using DiffEqFlux, OrdinaryDiffEq, Flux, Plots
# Generate data from a real ODE
u0 = Float32[2.; 0.]; datasize = 30
tspan = (0.0f0,1.5f0)
function trueODEfunc(du,u,p,t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)'true_A)'
end
t = range(tspan[1],tspan[2],length=datasize)
prob = ODEProblem(trueODEfunc,u0,tspan)
ode_data = Array(solve(prob,Tsit5(),saveat=t))
# Define a Neural ODE
dudt = Chain(x -> x.^3,
             Dense(2,75,tanh),
             Dense(75,2))
n_ode(x) = neural_ode(dudt,x,tspan,AutoTsit5(Rosenbrock23(autodiff=false)),
                      saveat=t,reltol=1e-7,abstol=1e-9)
function predict_n_ode()
    n_ode(u0)
end
loss_n_ode() = sum(abs2,ode_data .- predict_n_ode())
# Train the Neural ODE to match the data
data = Iterators.repeated((), 200)
cb = function () # callback to observe training: show the loss and plot the current fit
    display(loss_n_ode())
    cur_pred = Flux.data(predict_n_ode()) # unwrap the tracked array for plotting
    p1 = scatter(t,ode_data[1,:],label="data",legend=:bottomright)
    scatter!(p1,t,cur_pred[1,:],label="prediction")
    p2 = scatter(t,ode_data[2,:],label="data",legend=:top)
    scatter!(p2,t,cur_pred[2,:],label="prediction")
    display(plot(p1,p2,layout=(2,1)))
end
Flux.train!(loss_n_ode, Flux.params(dudt), data, Nesterov(0.0005), cb = cb)