@Libbum
Created October 28, 2020 13:15
A universal ODE implementation of a Rössler Attractor
using OrdinaryDiffEq
using ModelingToolkit
using DataDrivenDiffEq
using LinearAlgebra
using DiffEqSensitivity
using Optim
using Flux
using DiffEqFlux
using Plots
# Build up a Rössler attractor
function sys!(du, u, p, t)
    X, Y, Z = u
    a, b, c = p
    du[1] = -Y - Z
    du[2] = X + a * Y
    du[3] = b + Z * (X - c)
end
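# This is the Rössler system:
#   dX/dt = -Y - Z
#   dY/dt = X + a*Y
#   dZ/dt = b + Z*(X - c)
# With (a, b, c) = (0.2, 0.2, 5.7), used below, it sits in the classic chaotic regime.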
# Just using a small amount of data to train our NNs later
datasize = 30
tspan = (0.0f0, 3.0f0)
tsteps = range(tspan[1], tspan[2], length = datasize)
u0 = Float32[0.3, 0.5, 0.78]
p_ = Float32[0.2, 0.2, 5.7]
prob = ODEProblem(sys!, u0, tspan, p_)
solution =
    solve(prob, Vern7(), abstol = 1e-12, reltol = 1e-12, maxiters = 1e7, saveat = tsteps)
plot(solution, legend = false)
plot(solution, vars = (1, 2, 3))
# Let's assume we don't know how Y behaves.
# Ideal data
X = Array(solution)
# Ideal derivatives
DX = Array(solution(solution.t, Val{1}))
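# solution(t, Val{1}) evaluates the first derivative of the solver's dense
# interpolant, so DX holds essentially exact values of du/dt at the save points.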
# Add noise to the data
println("Generate noisy data")
Xₙ = X + Float32(1e-3) * randn(eltype(X), size(X))
# Generate a NN with three inputs [x, y, z] and three outputs [NN1, NN2, NN3]
ann = FastChain(FastDense(3, 32, tanh), FastDense(32, 32, tanh), FastDense(32, 3))
p = initial_params(ann)
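# FastChain/FastDense keep all weights in the flat vector returned by
# initial_params(ann); that vector p is what we will optimise during training.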
function dudt_(u, p, t)
    X, Y, Z = u
    a, b, c = p_
    # Replace each of the Y-dependent terms with an output of the NN.
    y = ann(u, p)
    [y[1] - Z, X + y[2], b + Z * (X - c) + y[3]]
end
prob_nn = ODEProblem(dudt_, u0, tspan, p)
sol_nn = solve(prob_nn, Tsit5(), u0 = u0, p = p, saveat = solution.t)
# The true result against the initial, random guess of the NN
plot(solution)
scatter!(sol_nn)
function predict(θ)
    Array(solve(
        prob_nn,
        Vern7(),
        u0 = u0,
        p = θ,
        saveat = solution.t,
        abstol = 1e-6,
        reltol = 1e-6,
        sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP()),
    ))
end
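# InterpolatingAdjoint with ReverseDiffVJP computes gradients of the loss with
# respect to θ by solving the adjoint ODE, using reverse-mode AD (ReverseDiff)
# for the vector-Jacobian products along the interpolated forward solution.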
# No regularisation right now
function loss(θ)
    pred = predict(θ)
    sum(abs2, Xₙ .- pred), pred
end
# Test
loss(p)
const losses = []
callback(θ, l, pred) = begin
    push!(losses, l)
    if length(losses) % 50 == 0
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
    end
    false
end
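# Two-stage optimisation: ADAM handles the rough initial part of the loss
# surface, then a quasi-Newton BFGS run refines the parameters from res1.minimizer.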
# First train with ADAM for better convergence
res1 = DiffEqFlux.sciml_train(loss, p, ADAM(0.01), cb = callback, maxiters = 200)
# Train with BFGS
res2 = DiffEqFlux.sciml_train(
    loss,
    res1.minimizer,
    BFGS(initial_stepnorm = 0.01),
    cb = callback,
    maxiters = 10000,
)
println("Final training loss after $(length(losses)) iterations: $(losses[end])")
# Plot the losses
plot(
    losses,
    yaxis = :log10,
    xaxis = :log10,
    xlabel = "Iterations",
    ylabel = "Loss",
    legend = false,
)
# Plot the data and the approximation
NNsolution = predict(res2.minimizer)
# Trained on noisy data vs real solution
plot(solution.t, NNsolution')
scatter!(solution.t, X')
# The learned derivatives
prob_nn2 = ODEProblem(dudt_, u0, tspan, res2.minimizer)
_sol = solve(prob_nn2, Tsit5())
DX_ = Array(_sol(solution.t, Val{1}))
plot(DX')
plot!(DX_')
# Ideal data for our missing parameters
L̄ = [-X[2, :]'; p_[1] * X[2, :]'; zeros(size(X, 2))']
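# Each row of L̄ is the true Y-dependent term removed from the corresponding
# equation: -Y in dX/dt, a*Y in dY/dt, and zero in dZ/dt (nothing is missing there).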
# NN guess
L̂ = ann(Xₙ, res2.minimizer)
scatter(L̄')
plot!(L̂')
# Plot the error
scatter(abs.(L̄ - L̂)', yaxis = :log10)
## Sparse Identification
# Create a Basis
@variables x y z
u = Operation[x; y; z]
polys = Operation[]
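# Build a library of candidate monomials x^i * y^j * z^k with exponents up to 4
# (pushed in three cyclic orderings of the variables) for SINDy to select from.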
for i in 0:4
    for j in 0:i
        for k in 0:j
            push!(polys, u[1]^i * u[2]^j * u[3]^k)
            push!(polys, u[2]^i * u[3]^j * u[1]^k)
            push!(polys, u[3]^i * u[1]^j * u[2]^k)
        end
    end
end
basis = Basis(polys, u)
# Create an optimizer for the SINDy problem
opt = SR3(1e-2)
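# SR3 (sparse relaxed regularized regression) with a sparsity threshold of 1e-2:
# the SINDy call below regresses the NN-predicted missing terms L̂ onto the
# candidate basis evaluated at the noisy states Xₙ, keeping only terms that
# survive the threshold.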
# Test on uode derivative data
println("SINDy on learned, partial, available data")
Ψ = SINDy(Xₙ, L̂, basis, opt; maxiter = 10000, normalize = true, denoise = true)
println(Ψ)
print_equations(Ψ)
# Extract the parameter
p̂ = parameters(Ψ)
println("First parameter guess : $(p̂)")
# This currently doesn't work due to an open issue I'm working on with
# the maintainer:
# https://github.com/SciML/DataDrivenDiffEq.jl/issues/164
# The result is still the correct one, but since nothing is found for the
# third equation, the current implementation will not let us build this step
# automatically, so I do it manually below.
# Until the issue is fixed, don't run this portion of the code; skip ahead to
# just after the second horizontal line.
# ----------------------------------------------------------------------------
# The parameters are a bit off, but the equations are recovered
# Start another SINDy run to get closer to the ground truth
# Create function
unknown_sys = ODESystem(Ψ)
unknown_eq = ODEFunction(unknown_sys)
# Just the equations
b = Basis((u, p, t) -> unknown_eq(u, [1.0; 1.0], t), u)
# Retune for better parameters -> we could also use DiffEqFlux or other parameter estimation tools here.
Ψf = SINDy(
    Xₙ[:, 2:end],
    L̂[:, 2:end],
    b,
    STRRidge(0.01),
    maxiter = 100,
    convergence_error = 1e-18,
) # Succeeds
println(Ψf)
p̂ = parameters(Ψf)
println("Second parameter guess : $(p̂)")
## As an alternative:
@parameters t p[1:3]
@variables x(t) y(t) z(t)
@derivatives D'~t
eqs = [D(x) ~ p[1] * y, D(y) ~ p[2] * y, D(z) ~ p[3] * y] #Fake, expect 0
unknown_sys = ODESystem(eqs)
unknown_eq = ODEFunction(unknown_sys)
b = Basis((u, p, t) -> unknown_eq(u, [1.0; 1.0], t), u)
Ψf = SINDy(
    Xₙ[:, 2:end],
    L̂[:, 2:end],
    b,
    STRRidge(0.01),
    maxiter = 100,
    convergence_error = 1e-18,
)
# Results in 2 non-NaN numbers, which we manually update below
# ----------------------------------------------------------------------------
p̂[1] = -0.9994146
p̂[2] = 0.20148616
# Create function manually
@parameters t p[1:2]
@variables x(t) y(t) z(t)
@derivatives D'~t
eqs = [D(x) ~ p[1] * y, D(y) ~ p[2] * y]
recovered_sys = ODESystem(eqs)
recovered_eq = ODEFunction(recovered_sys)
## Continue
# Now we can place the recovered mechanisms into our equation and observe the results
function dudt(u, p, t)
    X, Y, Z = u
    a, b, c = p_
    y = recovered_eq(u, p, t)
    # TODO: The result is flipped? Why?
    [y[2] - Z, X + y[1], b + Z * (X - c)]
end
approximate_prob = ODEProblem(dudt, u0, tspan, p̂)
approximate_solution = solve(approximate_prob, Tsit5(), saveat = 0.01)
# Plot
plot(solution)
plot!(approximate_solution)
## Simulation
# Look at long term prediction
t_long = (0.0, 50.0)
approximate_prob = ODEProblem(dudt, u0, t_long, p̂)
approximate_solution_long = solve(approximate_prob, Tsit5(), saveat = 0.1)
plot(approximate_solution_long)
true_prob = ODEProblem(sys!, u0, t_long, p_)
true_solution_long = solve(true_prob, Tsit5(), saveat = approximate_solution_long.t)
plot!(true_solution_long)
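# The network only ever saw t ∈ (0, 3); integrating the recovered model out to
# t = 50 and overlaying the true solution shows how well the discovered
# equations extrapolate beyond the training window.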
c1 = RGBA(116 / 255, 206 / 255, 227 / 255, 1) # LBlue
c2 = RGBA(31 / 255, 120 / 255, 180 / 255, 1) # Blue
c3 = RGBA(178 / 255, 223 / 255, 138 / 255, 1) # LGreen
c4 = RGBA(51 / 255, 160 / 255, 44 / 255, 1) # Green
c5 = RGBA(251 / 255, 154 / 255, 153 / 255, 1) # LRed
c6 = RGBA(227 / 255, 26 / 255, 28 / 255, 1) # Red
p1 = plot(
    0.1:0.1:tspan[end],
    abs.(Array(solution) .- NNsolution)' .+ eps(Float32),
    lw = 3,
    yaxis = :log,
    title = "Timeseries of UODE Error",
    color = [c2 c4 c6],
    xlabel = "t",
    label = ["x(t)" "y(t)" "z(t)"],
    titlefont = "Helvetica",
    legendfont = "Helvetica",
    legend = :bottomright,
)
# Plot L₂
p2 = plot(
    X[1, :],
    X[2, :],
    L̂[2, :],
    lw = 3,
    title = "Neural Network Fit of U2(t)",
    color = c2,
    label = "Neural Network",
    xaxis = "x",
    yaxis = "y",
    titlefont = "Helvetica",
    legendfont = "Helvetica",
    legend = :bottomright,
)
plot!(X[1, :], X[2, :], L̄[2, :], lw = 3, label = "True Missing Term", color = c1)
p3 = scatter(
    solution,
    color = [c1 c3 c5],
    label = :none,
    titlefont = "Helvetica",
    legendfont = "Helvetica",
    markersize = 3,
    legend = :topleft,
    msc = :auto,
)
plot!(
    p3,
    true_solution_long,
    color = [c1 c3 c5],
    linestyle = :dot,
    lw = 3,
    label = ["True x(t)" "True y(t)" "True z(t)"],
)
plot!(
    p3,
    approximate_solution_long,
    color = [c2 c4 c6],
    lw = 1,
    label = ["Estimated x(t)" "Estimated y(t)" "Estimated z(t)"],
)
plot!(p3, [2.99, 3.01], [-10, 20], lw = 2, color = :black, label = :none) # Mark the end of the training window at t ≈ 3