|
using DiffEqFlux, OrdinaryDiffEq, Flux, Optim, Plots, DiffEqSensitivity |
|
|
|
# Problem setup: initial state, number of samples, integration window, and the
# evenly spaced sample times spanning that window (all in Float32).
u0 = Float32[2.0, 0.0]
datasize = 30
tspan = (0.0f0, 1.5f0)
tsteps = range(first(tspan), last(tspan); length = datasize)
|
|
|
"""
    trueODEfunc(du, u, p, t)

In-place right-hand side of the reference ODE: overwrite `du` with
`true_A' * (u .^ 3)`, where `true_A` is the fixed 2×2 matrix defined in the body.
`p` and `t` are accepted for the ODE-function signature but are not used.
Returns `du`.
"""
function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    cubed = u .^ 3
    # Same value as ((u.^3)' * true_A)': a row vector times a matrix is the
    # adjoint of (matrix' * vector), and the outer adjoint undoes the wrapper.
    du .= true_A' * cubed
    return du
end
|
|
|
# Ground-truth training data: integrate the reference ODE from `u0` over `tspan`
# with Tsit5, saving the solution at every entry of `tsteps`, and convert the
# solution object to a plain array.
prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(prob_trueode, Tsit5(); saveat = tsteps))
|
|
|
# Neural right-hand side: cube the state element-wise, then feed it through a
# 2 -> 50 (tanh) -> 2 dense network. The chain is called as `dudt2(x, p)`, i.e.
# the parameter vector is passed in explicitly on every call.
dudt2 = FastChain((x, p) -> x .^ 3, FastDense(2, 50, tanh), FastDense(50, 2))
|
"""
    neural_ode_f(u, p, t)

Out-of-place ODE right-hand side backed by the network `dudt2`: return
`dudt2(u, p)`. The time argument `t` is ignored.
"""
function neural_ode_f(u, p, t)
    return dudt2(u, p)
end
|
# Starting parameter vector for `dudt2`, and the neural ODE problem that uses it
# as its parameters, with the same initial state and time span as the reference
# problem.
pinit = initial_params(dudt2)
prob = ODEProblem(neural_ode_f, u0, tspan, pinit)
|
|
|
"""
    predict_neuralode(p)

Solve the neural ODE `prob` with its parameters replaced by `p` and return the
trajectory, saved at `tsteps`, as a plain array.

The solve uses `Tsit5()` and requests `BacksolveAdjoint` sensitivities with
`ReverseDiffVJP(true)` for the vector-Jacobian products.
"""
function predict_neuralode(p)
    candidate = remake(prob; p = p)
    adjoint_method = BacksolveAdjoint(; autojacvec = ReverseDiffVJP(true))
    trajectory = solve(candidate, Tsit5(); saveat = tsteps, sensealg = adjoint_method)
    return Array(trajectory)
end
|
|
|
"""
    loss_neuralode(p)

Return `(loss, pred)`, where `pred = predict_neuralode(p)` and `loss` is the sum
of squared element-wise differences between `ode_data` and `pred`.
"""
function loss_neuralode(p)
    pred = predict_neuralode(p)
    residual = ode_data .- pred
    return sum(abs2, residual), pred
end
|
|
|
# Training callback taking the current parameters `p`, loss `l`, and prediction
# `pred`. The display/plotting code is commented out, so the arguments and the
# `doplot` keyword are currently unused and the callback always returns `false`.
# NOTE(review): in DiffEqFlux's callback convention `false` means "do not halt
# training" -- confirm against the installed version.
callback = (p, l, pred; doplot = true) -> begin
    # display(l)
    # Plot the current prediction against the data:
    # plt = scatter(tsteps, ode_data[1, :], label = "data")
    # scatter!(plt, tsteps, pred[1, :], label = "prediction")
    # if doplot
    #     display(plot(plt))
    # end
    false
end
|
|
|
# Train the neural ODE: minimise `loss_neuralode` starting from `pinit` using
# ADAM(0.05), capped at 500 iterations, calling `callback` during training.
# `@time` prints elapsed time and allocations for the whole call; the first call
# in a fresh session also includes compilation time.
@time result_neuralode = DiffEqFlux.sciml_train(
    loss_neuralode, pinit, ADAM(0.05); cb = callback, maxiters = 500
)
|
|
|
#=
Sample output of the `@time` training call above:

  2.687161 seconds (17.79 M allocations: 1002.418 MiB, 7.41% gc time)

 * Status: success

 * Candidate solution
    Final objective value: 2.761669e-02

 * Found with
    Algorithm: ADAM

 * Convergence measures
    |x - x'| = NaN ≰ 0.0e+00
    |x - x'|/|x'| = NaN ≰ 0.0e+00
    |f(x) - f(x')| = NaN ≰ 0.0e+00
    |f(x) - f(x')|/|f(x')| = NaN ≰ 0.0e+00
    |g(x)| = NaN ≰ 0.0e+00

 * Work counters
    Seconds run: 3 (vs limit Inf)
    Iterations: 500
    f(x) calls: 500
    ∇f(x) calls: 500
=#