# GitHub Gist by @seyedaha (seyedaha/85e7b10a37062239b55bb312a8eb6c9e),
# created January 9, 2024 00:08.
#
# Trying to embed a neural network in an MTK-based HIRES ODE system.
# SciML Tools
using DifferentialEquations, ModelingToolkit, SciMLSensitivity
using Optimization, OptimizationOptimisers, OptimizationOptimJL
# Standard Libraries
using LinearAlgebra, Statistics
# External Libraries
using ComponentArrays, Lux, Zygote, Plots, StableRNGs
# Others
using ModelingToolkit
# Seeded random number generator, so repeated runs draw the same numbers
rng = StableRNG(1111)
# %%-----------------------------------------
# Independent variable `t` and the eight states, each with its initial value
@variables begin
    t
    y1(t) = 1.0
    y2(t) = 0.0
    y3(t) = 0.0
    y4(t) = 0.0
    y5(t) = 0.0
    y6(t) = 0.0
    y7(t) = 0.0
    y8(t) = 0.0057
end
# Rate constants, each with its default value
@parameters begin
    k1 = 1.71
    k2 = 280.0
    k3 = 8.32
    k4 = 0.69
    k5 = 0.43
    k6 = 1.81
end
# Time-derivative operator used on the left-hand side of every equation
D = Differential(t)
eqs = [
    D(y1) ~ -k1 * y1 + k5 * y2 + k3 * y3 + 0.0007,
    D(y2) ~ k1 * y1 - 8.75 * y2,
    D(y3) ~ -10.03 * y3 + k5 * y4 + 0.035 * y5,
    D(y4) ~ k3 * y2 + k1 * y3 - 1.12 * y4,
    D(y5) ~ -1.745 * y5 + k5 * y6 + k5 * y7,
    D(y6) ~ -k2 * y6 * y8 + k4 * y4 + k1 * y5 - k5 * y6 + k4 * y7,
    D(y7) ~ k2 * y6 * y8 - k6 * y7,
    D(y8) ~ -k2 * y6 * y8 + k6 * y7,
]
@named odesys = ODESystem(eqs)
# Built without explicit arguments; its `u0` and `p` fields are reused further down
odeprob = ODEProblem(odesys)
# %%-----------------------------------------
# Turn the right-hand sides of `odesys` into a plain Julia function of the state vector
n = length(states(odesys))
# Numeric value substituted for each symbolic rate constant
ks = Dict(k1 => 1.71, k2 => 280.0, k3 => 8.32, k4 => 0.69, k5 => 0.43, k6 => 1.81)
rhs = [eq.rhs for eq in equations(odesys)]
_rhs = map(ex -> substitute(ex, ks), rhs)
# Keep the first expression returned by `build_function` and evaluate it into a callable
rhs_f = eval(build_function(_rhs, states(odesys), expression=Val{true})[1])
# In-place right-hand side backed by the generated function `rhs_f`.
# `p` and `t` are part of the signature but are not read.
function hires!(du, u, p, t)
    derivative = rhs_f(u)
    du .= derivative
    return du
end
# Solver settings, splatted into `solve` as keyword arguments below
config = (; alg=Rodas5P(), tspan=[0, 10], saveat=0.1, abstol=1e-8, reltol=1e-8)
# Initial state and parameter values are taken from the symbolic problem
prob = ODEProblem(hires!, odeprob.u0, [0, 10], odeprob.p)
# This approach solves the ODE successfully:
sol = solve(prob; config...)
plot(sol)
# %%-----------------------------------------
# generating ground-truth data
# In-place right-hand side of the eight-state system with every rate constant written
# as a numeric literal. `p` and `t` are part of the signature but are not read.
function hires!(du, u, p, t)
    y1, y2, y3, y4, y5, y6, y7, y8 = u
    # The y6*y8 product term appears in the last three equations; compute it once
    coupling = 280.0 * y6 * y8
    du[1] = -1.71 * y1 + 0.43 * y2 + 8.32 * y3 + 0.0007
    du[2] = 1.71 * y1 - 8.75 * y2
    du[3] = -10.03 * y3 + 0.43 * y4 + 0.035 * y5
    du[4] = 8.32 * y2 + 1.71 * y3 - 1.12 * y4
    du[5] = -1.745 * y5 + 0.43 * y6 + 0.43 * y7
    du[6] = -coupling + 0.69 * y4 + 1.71 * y5 - 0.43 * y6 + 0.69 * y7
    du[7] = coupling - 1.81 * y7
    du[8] = -coupling + 1.81 * y7
end
# Experiment setup: integrate over [0, 30] from the initial state and parameter values
# stored in `odeprob`
tspan = (0.0, 30.0)
u0 = odeprob.u0
p_ = odeprob.p
prob = ODEProblem(hires!, u0, tspan, p_)
# Save every 0.1 on [0, 5], then every 0.5 starting from 5.3
solution = solve(
    prob, Rodas5P();
    abstol=1e-8, reltol=1e-8, saveat=vcat(0:0.1:5, 5.3:0.5:30),
)
# Training data. No noise is added here: `Xₙ` is the same array as `X`.
X = Array(solution)
# NOTE: this rebinds the global `t` (until now the symbolic independent variable)
# to the vector of saved time points
t = solution.t
Xₙ = X
# State names as a 1x8 row of strings
labels = permutedims(string.(states(odesys)))
# %%-----------------------------------------
# Element-wise exp(-x^2); accepts a scalar or an array
rbf(x) = @. exp(-(x^2))
# Feed-forward network: 8 inputs -> 20 -> 20 -> 8 outputs, tanh on the two hidden layers
U = Lux.Chain(Lux.Dense(8 => 20, tanh), Lux.Dense(20 => 20, tanh), Lux.Dense(20 => 8))
# Initial parameters `p` and layer states `st` of the network, drawn from the seeded `rng`
p, st = Lux.setup(rng, U)
# Second global name for the same layer states; read by `ude_dynamics!`
_st = st
# Define the hybrid model
#
# In-place right-hand side of the hybrid ODE: the known dynamics `rhs_f(u)` plus the
# output of the network `U`.
#   du     - output buffer, overwritten with the state derivative
#   u      - current state vector
#   p      - parameters of the network `U`
#   t      - current time (not read)
#   p_true - known-model parameters (not read; the rate constants are already
#            substituted into `rhs_f`)
function ude_dynamics!(du, u, p, t, p_true)
    û = U(u, p, _st)[1] # Network prediction
    # `rhs_f` is a function and must be evaluated at `u`. Broadcasting the function
    # object itself (`rhs_f .+ û`) throws a MethodError on the first call.
    du .= rhs_f(u) .+ û
    return nothing
end
# Four-argument right-hand side: forwards to `ude_dynamics!`, passing the global `p_`
# as the known-model parameters
function nn_dynamics!(du, u, p, t)
    return ude_dynamics!(du, u, p, t, p_)
end
# Define the problem: the hybrid right-hand side `nn_dynamics!`, started from the first
# column of the data `Xₙ`, over `tspan`, with the network parameters `p` returned by
# `Lux.setup`. `predict` remakes this problem with new `u0`, `tspan` and `p`.
prob_nn = ODEProblem(nn_dynamics!, Xₙ[:, 1], tspan, p)
# %%-----------------------------------------
# Solve the hybrid problem for network parameters `θ` and return the trajectory as an
# array.
#   θ - network parameters
#   X - initial state (defaults to the first column of `Xₙ`)
#   T - save times; the integration runs from `T[1]` to `T[end]` (defaults to `t`)
function predict(θ, X=Xₙ[:, 1], T=t)
    window = (T[1], T[end])
    adjoint = QuadratureAdjoint(autojacvec=ReverseDiffVJP(true))
    trial = remake(prob_nn; u0=X, tspan=window, p=θ)
    trajectory = solve(
        trial, TRBDF2();
        saveat=T, abstol=1e-6, reltol=1e-6, sensealg=adjoint,
    )
    return Array(trajectory)
end
# Training objective for network parameters `θ`.
# Returns `(value, X̂)`: the mean squared difference between `Xₙ` and the prediction,
# plus 0.2 times the sum of squares of the negative entries of the prediction.
function loss(θ)
    X̂ = predict(θ)
    misfit = mean(abs2, Xₙ .- X̂)
    # Only entries below zero contribute: `min(x, 0)` is zero for x >= 0
    negativity = sum(abs2, map(x -> min(x, 0), X̂))
    return misfit + 0.2negativity, X̂
end
# Loss values, one entry appended per call of `callback`
losses = Vector{Float64}()
# Select the GR plotting backend
gr()
# Optimiser callback: append the loss `l` to `losses`, print it and, when `doplot` is
# true, plot row 7 of the prediction `pred` against row 7 of the data.
# Returns `false` on every call.
callback = function (p, l, pred; doplot=true)
    push!(losses, l)
    # `% 1` is always zero, so the report below is produced on every call
    length(losses) % 1 == 0 || return false
    println("Current loss after $(length(losses)) iterations: $(losses[end])")
    # plot current prediction against data
    if doplot
        plt = scatter(t, Xₙ[7, :]; markersize=3, alpha=0.7, label="data")
        plot!(plt, t, pred[7, :]; label="prediction")
        display(plot(plt))
    end
    return false
end
# Stage 1: ADAM from the initial network parameters, differentiated with Zygote
adtype = Optimization.AutoZygote()
# The second argument of the objective is not used by `loss`
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentVector{Float64}(p))
res1 = Optimization.solve(optprob, ADAM(0.01); callback, maxiters=500)
# Stage 2: LBFGS, started from the ADAM result
optprob2 = Optimization.OptimizationProblem(optf, res1.u)
res2 = Optimization.solve(optprob2, Optim.LBFGS(); callback, maxiters=300)
# (GitHub page footer) Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.