Created
October 28, 2020 13:15
-
-
Save Libbum/6d96586ff9bd50536dffb2b03896b911 to your computer and use it in GitHub Desktop.
A universal ODE implementation of a Rössler Attractor
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using OrdinaryDiffEq | |
using ModelingToolkit | |
using DataDrivenDiffEq | |
using LinearAlgebra | |
using DiffEqSensitivity | |
using Optim | |
using Flux | |
using DiffEqFlux | |
using Plots | |
# Build up a Rössler attractor
"""
    sys!(du, u, p, t)

In-place right-hand side of the Rössler system.

With `(X, Y, Z) = u` and `(a, b, c) = p`, the derivative is written into `du` as
`du[1] = -Y - Z`, `du[2] = X + a*Y` and `du[3] = b + Z*(X - c)`.
The time argument `t` is not used.
"""
function sys!(du, u, p, t)
    xs, ys, zs = u
    pa, pb, pc = p
    du[1] = -ys - zs
    du[2] = xs + pa * ys
    du[3] = pb + zs * (xs - pc)
end
# Only a small number of samples is generated; they are the training data for
# the neural network further down.
datasize = 30
tspan = (0.0f0, 3.0f0)
tsteps = range(first(tspan), last(tspan); length = datasize)

# Initial state and the (a, b, c) parameters consumed by `sys!`.
u0 = Float32[0.3, 0.5, 0.78]
p_ = Float32[0.2, 0.2, 5.7]

# Reference trajectory, saved exactly at the `tsteps` instants.
prob = ODEProblem(sys!, u0, tspan, p_)
solution = solve(
    prob,
    Vern7();
    abstol = 1e-12,
    reltol = 1e-12,
    maxiters = 1e7,
    saveat = tsteps,
)

# Time series of all three states, then the trajectory in phase space.
plot(solution; legend = false)
plot(solution; vars = (1, 2, 3))
# From here on we pretend not to know how Y enters the dynamics.

# Ideal (noise-free) states, one column per saved time point
X = Array(solution)

# Ideal derivatives, evaluated at the saved time points
DX = Array(solution(solution.t, Val{1}))

# Perturb the ideal data with Gaussian noise scaled by 1e-3
println("Generate noisy data")
Xₙ = X .+ Float32(1e-3) .* randn(eltype(X), size(X))
# Neural network mapping the three states [x, y, z] to three outputs
# [NN1, NN2, NN3]: two tanh hidden layers of width 32 and a linear output layer.
ann = FastChain(
    FastDense(3, 32, tanh),
    FastDense(32, 32, tanh),
    FastDense(32, 3),
)
# Initial (untrained) parameter vector of the network
p = initial_params(ann)
"""
    dudt_(u, p, t)

Out-of-place right-hand side of the universal ODE.

Compared with `sys!`, the terms that involve `Y` are removed and the outputs of
the network `ann(u, p)` are added in their place; the third output is added on
top of the unchanged third equation. `p` holds the network parameters, while
the physical constants `b` and `c` are read from the global `p_`. `t` is unused.
"""
function dudt_(u, p, t)
    X, _, Z = u
    _, b, c = p_
    # Network outputs stand in for every Y-dependent contribution.
    nn = ann(u, p)
    return [nn[1] - Z, X + nn[2], b + Z * (X - c) + nn[3]]
end
# Universal ODE problem, solved once with the untrained parameters and saved at
# the same instants as the reference solution.
prob_nn = ODEProblem(dudt_, u0, tspan, p)
sol_nn = solve(prob_nn, Tsit5(); u0 = u0, p = p, saveat = solution.t)

# Reference trajectory (lines) against the initial, random guess of the NN (markers)
plot(solution)
scatter!(sol_nn)
"""
    predict(θ)

Solve `prob_nn` from `u0` with network parameters `θ` and return the trajectory
as an array sampled at `solution.t`.
"""
function predict(θ)
    # Sensitivity algorithm used when this solve is differentiated.
    adjoint = InterpolatingAdjoint(autojacvec = ReverseDiffVJP())
    trajectory = solve(
        prob_nn,
        Vern7();
        u0 = u0,
        p = θ,
        saveat = solution.t,
        abstol = 1e-6,
        reltol = 1e-6,
        sensealg = adjoint,
    )
    return Array(trajectory)
end
"""
    loss(θ)

Return `(sse, prediction)`: the sum of squared differences between the noisy
data `Xₙ` and `predict(θ)`, together with the prediction itself.
No regularisation term is included.
"""
function loss(θ)
    prediction = predict(θ)
    sse = sum(abs2, Xₙ .- prediction)
    return sse, prediction
end
# Smoke test: evaluate the loss once with the initial parameters
loss(p)

# Loss value of every optimiser iteration, appended by `callback`
const losses = []

"""
    callback(θ, l, pred)

Record the loss `l` in `losses` and print a progress line on every 50th recorded
value. Always returns `false`.
"""
function callback(θ, l, pred)
    push!(losses, l)
    if iszero(length(losses) % 50)
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
    end
    return false
end
# Stage 1: a short ADAM run from the initial parameters, for better convergence
res1 = DiffEqFlux.sciml_train(loss, p, ADAM(0.01); cb = callback, maxiters = 200)

# Stage 2: continue from the ADAM result with BFGS
res2 = DiffEqFlux.sciml_train(
    loss,
    res1.minimizer,
    BFGS(initial_stepnorm = 0.01);
    cb = callback,
    maxiters = 10000,
)
println("Final training loss after $(length(losses)) iterations: $(losses[end])")

# Loss history of both stages on log-log axes
plot(
    losses;
    xaxis = :log10,
    yaxis = :log10,
    xlabel = "Iterations",
    ylabel = "Loss",
    legend = false,
)
# Trajectory produced by the trained network
NNsolution = predict(res2.minimizer)

# Fit obtained from noisy data (lines) against the noise-free solution (markers)
plot(solution.t, NNsolution')
scatter!(solution.t, X')

# Derivatives of the trained model, evaluated at the reference time points,
# overlaid on the ideal derivatives
prob_nn2 = ODEProblem(dudt_, u0, tspan, res2.minimizer)
_sol = solve(prob_nn2, Tsit5())
DX_ = Array(_sol(solution.t, Val{1}))
plot(DX')
plot!(DX_')

# Ideal values of the three missing terms, one row each: -Y, a*Y and 0
L̄ = vcat(-X[2, :]', p_[1] * X[2, :]', zeros(size(X, 2))')

# What the trained network returns for the noisy states
L̂ = ann(Xₙ, res2.minimizer)
scatter(L̄')
plot!(L̂')

# Absolute error of the network's guess on a log axis
scatter(abs.(L̄ - L̂)'; yaxis = :log10)
## Sparse Identification

# Candidate library: monomials in x, y, z with exponents 0 ≤ k ≤ j ≤ i ≤ 4,
# pushed in three cyclic rotations of the variables for each exponent triple.
@variables x y z
u = Operation[x; y; z]
polys = Operation[]
for i in 0:4, j in 0:i, k in 0:j
    push!(polys, u[1]^i * u[2]^j * u[3]^k)
    push!(polys, u[2]^i * u[3]^j * u[1]^k)
    push!(polys, u[3]^i * u[1]^j * u[2]^k)
end
basis = Basis(polys, u)

# Optimiser for the SINDy problem
opt = SR3(1e-2)

# Regress the network outputs L̂ on the library evaluated at the noisy states
println("SINDy on learned, partial, available data")
Ψ = SINDy(
    Xₙ,
    L̂,
    basis,
    opt;
    maxiter = 10000,
    normalize = true,
    denoise = true,
)
println(Ψ)
print_equations(Ψ)

# Coefficients of the identified terms
p̂ = parameters(Ψ)
println("First parameter guess : $(p̂)")
# The automatic refinement below does not work at the moment because of an open
# issue being discussed with the maintainer:
# https://github.com/SciML/DataDrivenDiffEq.jl/issues/164
# The result is still the correct one, but nothing is found for the third
# output, and the current implementation cannot build this step automatically
# in that case. It is therefore done manually further down.
# Do NOT run this portion of the code until the issue is fixed; resume after
# the second horizontal line.
# ----------------------------------------------------------------------------
# The parameters are a bit off, but the equations are recovered.
# A second SINDy run is used to get closer to the ground truth.

# Turn the identified equations into a callable function
unknown_sys = ODESystem(Ψ)
unknown_eq = ODEFunction(unknown_sys)

# Basis made of just the recovered equations, with every coefficient set to one
b = Basis((u, p, t) -> unknown_eq(u, [1.0; 1.0], t), u)

# Retune for better parameters -> DiffEqFlux or other parameter estimation
# tools could be used here as well.
Ψf = SINDy(
    Xₙ[:, 2:end],
    L̂[:, 2:end],
    b,
    STRRidge(0.01);
    maxiter = 100,
    convergence_error = 1e-18,
) # Succeed
println(Ψf)
p̂ = parameters(Ψf)
println("Second parameter guess : $(p̂)")

## As an alternate:
@parameters t p[1:3]
@variables x(t) y(t) z(t)
@derivatives D'~t
# The third equation is a placeholder; its coefficient is expected to be 0.
eqs = [D(x) ~ p[1] * y, D(y) ~ p[2] * y, D(z) ~ p[3] * y]
unknown_sys = ODESystem(eqs)
unknown_eq = ODEFunction(unknown_sys)
# NOTE(review): `p` is declared with three entries above, but only two values
# are passed to `unknown_eq` here — confirm whether this is intended.
b = Basis((u, p, t) -> unknown_eq(u, [1.0; 1.0], t), u)
Ψf = SINDy(
    Xₙ[:, 2:end],
    L̂[:, 2:end],
    b,
    STRRidge(0.01);
    maxiter = 100,
    convergence_error = 1e-18,
)
# Results in 2 non-NaN numbers, which are entered manually below
# ----------------------------------------------------------------------------
# Hand-entered coefficients of the two recovered terms
p̂[1], p̂[2] = -0.9994146, 0.20148616

# Build the recovered two-equation system by hand: both terms are linear in y.
@parameters t p[1:2]
@variables x(t) y(t) z(t)
@derivatives D'~t
eqs = [
    D(x) ~ p[1] * y,
    D(y) ~ p[2] * y,
]
recovered_sys = ODESystem(eqs)
recovered_eq = ODEFunction(recovered_sys)
## Continue
# Place the recovered mechanisms back into the equations and observe the results.
"""
    dudt(u, p, t)

Out-of-place right-hand side that combines the known terms with the recovered
terms returned by `recovered_eq(u, p, t)`.

The physical constants `b` and `c` are read from the global `p_`; `p` is passed
on to `recovered_eq`.
"""
function dudt(u, p, t)
    X, _, Z = u
    _, b, c = p_
    recovered = recovered_eq(u, p, t)
    #TODO: The result is flipped? Why?
    # Element 2 feeds the first equation and element 1 the second.
    return [recovered[2] - Z, X + recovered[1], b + Z * (X - c)]
end
# Recovered model over the training interval, on a 0.01 output grid
approximate_prob = ODEProblem(dudt, u0, tspan, p̂)
approximate_solution = solve(approximate_prob, Tsit5(); saveat = 0.01)

# Plot
plot(solution)
plot!(approximate_solution)

## Simulation
# Long-term prediction: integrate the recovered and the true system up to t = 50
# and compare them on the same output grid.
t_long = (0.0, 50.0)
approximate_prob = ODEProblem(dudt, u0, t_long, p̂)
approximate_solution_long = solve(approximate_prob, Tsit5(); saveat = 0.1)
plot(approximate_solution_long)

true_prob = ODEProblem(sys!, u0, t_long, p_)
true_solution_long = solve(
    true_prob,
    Tsit5();
    saveat = approximate_solution_long.t,
)
plot!(true_solution_long)
# Opaque colour from 8-bit red/green/blue channel values.
rgb255(r, g, b) = RGBA(r / 255, g / 255, b / 255, 1)

# Palette for the summary figures: a light and a dark shade per hue.
c1 = rgb255(116, 206, 227) # LBlue
c2 = rgb255(31, 120, 180)  # Blue
c3 = rgb255(178, 223, 138) # LGreen
c4 = rgb255(51, 160, 44)   # Green
c5 = rgb255(251, 154, 153) # LRed
c6 = rgb255(227, 26, 28)   # Red
# Absolute error between the reference trajectory and the trained UODE for each
# state, on a log axis. `eps(Float32)` is added so that exact zeros can still be
# drawn on the log scale.
# The abscissa is `solution.t`: both arrays are sampled at those instants
# (`saveat = tsteps` for the reference, `saveat = solution.t` in `predict`).
# They start at 0 and are not spaced by 0.1, so a hard-coded 0.1-step axis of
# the same length would place every sample at the wrong time.
p1 = plot(
    solution.t,
    abs.(Array(solution) .- NNsolution)' .+ eps(Float32),
    lw = 3,
    yaxis = :log,
    title = "Timeseries of UODE Error",
    color = [c2 c4 c6],
    xlabel = "t",
    label = ["x(t)" "y(t)" "z(t)"],
    titlefont = "Helvetica",
    legendfont = "Helvetica",
    legend = :bottomright,
)
# Plot L₂: second network output over the (x, y) plane, with the ideal missing
# term drawn on top for comparison.
p2 = plot(
    X[1, :],
    X[2, :],
    L̂[2, :];
    label = "Neural Network",
    color = c2,
    lw = 3,
    title = "Neural Network Fit of U2(t)",
    xaxis = "x",
    yaxis = "y",
    legend = :bottomright,
    titlefont = "Helvetica",
    legendfont = "Helvetica",
)
plot!(
    X[1, :],
    X[2, :],
    L̄[2, :];
    label = "True Missing Term",
    color = c1,
    lw = 3,
)
# Extrapolation figure: reference samples on the training interval (markers),
# the true long-term solution (dotted) and the recovered model's long-term
# solution (solid).
p3 = scatter(
    solution;
    label = :none,
    color = [c1 c3 c5],
    markersize = 3,
    msc = :auto,
    legend = :topleft,
    titlefont = "Helvetica",
    legendfont = "Helvetica",
)
plot!(
    p3,
    true_solution_long;
    label = ["True x(t)" "True y(t)" "True z(t)"],
    color = [c1 c3 c5],
    linestyle = :dot,
    lw = 3,
)
plot!(
    p3,
    approximate_solution_long;
    label = ["Estimated x(t)" "Estimated y(t)" "Estimated z(t)"],
    color = [c2 c4 c6],
    lw = 1,
)
# Black marker line at t ≈ 3, the end of `tspan`
plot!(p3, [2.99, 3.01], [-10, 20]; lw = 2, color = :black, label = :none)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment