-
-
Save seyedaha/85e7b10a37062239b55bb312a8eb6c9e to your computer and use it in GitHub Desktop.
Trying to embed a neural network in an MTK-based HIRES ODE system
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# SciML Tools | |
using DifferentialEquations, ModelingToolkit, SciMLSensitivity | |
using Optimization, OptimizationOptimisers, OptimizationOptimJL | |
# Standard Libraries | |
using LinearAlgebra, Statistics | |
# External Libraries | |
using ComponentArrays, Lux, Zygote, Plots, StableRNGs | |
# Others | |
using ModelingToolkit | |
# Set a random seed for reproducible behaviour
rng = StableRNG(1111)
# %%-----------------------------------------
# Independent variable of the symbolic model.
@variables t
# The eight HIRES states together with their initial values.
@variables begin
    y1(t) = 1.0
    y2(t) = 0.0
    y3(t) = 0.0
    y4(t) = 0.0
    y5(t) = 0.0
    y6(t) = 0.0
    y7(t) = 0.0
    y8(t) = 0.0057
end
# Rate constants with their default values.
@parameters begin
    k1 = 1.71
    k2 = 280.0
    k3 = 8.32
    k4 = 0.69
    k5 = 0.43
    k6 = 1.81
end
D = Differential(t)
# One equation per state, in state order y1..y8.
eqs = [
    D(y1) ~ -k1 * y1 + k5 * y2 + k3 * y3 + 0.0007,
    D(y2) ~ k1 * y1 - 8.75 * y2,
    D(y3) ~ -10.03 * y3 + k5 * y4 + 0.035 * y5,
    D(y4) ~ k3 * y2 + k1 * y3 - 1.12 * y4,
    D(y5) ~ -1.745 * y5 + k5 * y6 + k5 * y7,
    D(y6) ~ -k2 * y6 * y8 + k4 * y4 + k1 * y5 - k5 * y6 + k4 * y7,
    D(y7) ~ k2 * y6 * y8 - k6 * y7,
    D(y8) ~ -k2 * y6 * y8 + k6 * y7,
]
@named odesys = ODESystem(eqs)
# Problem assembled from the declared defaults; the script later reads its `u0` and `p`.
odeprob = ODEProblem(odesys)
# %%-----------------------------------------
# Extracting the RHS function of the ODE system as a plain Julia function of the
# state vector, with the rate constants substituted by their numeric values.
# Number of states in the symbolic system.
n = length(states(odesys))
# Numeric rate constants taken from the problem built from the `@parameters`
# defaults, so they cannot drift from the symbolic model.
# NOTE(review): assumes MTK's vector-valued `p`, ordered like `parameters(odesys)`.
ks = Dict(parameters(odesys) .=> odeprob.p)
rhs = [eq.rhs for eq in equations(odesys)]
_rhs = [substitute(expr, ks) for expr in rhs]
# `build_function` returns (out-of-place, in-place) expressions; keep the
# out-of-place one, `rhs_f(u)`, and compile it with `eval`.
rhs_f = build_function(_rhs, states(odesys), expression = Val{true})[1] |> eval
"""
    hires_mtk!(du, u, p, t)

In-place HIRES right-hand side using the MTK-generated `rhs_f`. `p` and `t` are
ignored because the rate constants were already substituted into `rhs_f`.

Deliberately not called `hires!`: the ground-truth cell defines a hand-written
`hires!(du, u, p, t)`, and sharing the name would make each cell silently
overwrite the other's method.
"""
function hires_mtk!(du, u, p, t)
    du .= rhs_f(u)
    return nothing
end
# Solver settings for the sanity check below. `tspan` only builds the problem:
# it is not a `solve` keyword, so it must not be splatted into `solve`.
config = (
    alg = Rodas5P(),
    tspan = (0.0, 10.0),
    saveat = 0.1,
    abstol = 1e-8,
    reltol = 1e-8,
)
prob = ODEProblem(hires_mtk!, odeprob.u0, config.tspan, odeprob.p)
# This approach solves the ODE successfully:
sol = solve(prob, config.alg;
    saveat = config.saveat, abstol = config.abstol, reltol = config.reltol)
plot(sol)
# %%-----------------------------------------
# generating ground-truth data
"""
    hires!(du, u, p, t)

Hand-written in-place right-hand side of the 8-state HIRES system with the rate
constants hard-coded (k1 = 1.71, k2 = 280.0, k3 = 8.32, k4 = 0.69, k5 = 0.43,
k6 = 1.81). `p` and `t` are ignored. Writes the derivative into `du` and returns
`nothing`.
"""
function hires!(du, u, p, t)
    y1, y2, y3, y4, y5, y6, y7, y8 = u
    # Terms shared by several equations: k2*y6*y8 and k6*y7.
    r = 280.0 * y6 * y8
    s = 1.81 * y7
    du[1] = -1.71 * y1 + 0.43 * y2 + 8.32 * y3 + 0.0007
    du[2] = 1.71 * y1 - 8.75 * y2
    du[3] = -10.03 * y3 + 0.43 * y4 + 0.035 * y5
    du[4] = 8.32 * y2 + 1.71 * y3 - 1.12 * y4
    du[5] = -1.745 * y5 + 0.43 * y6 + 0.43 * y7
    du[6] = -r + 0.69 * y4 + 1.71 * y5 - 0.43 * y6 + 0.69 * y7
    du[7] = r - s
    du[8] = -r + s
    return nothing
end
# Define the experimental parameter
tspan = (0.0, 30.0)
u0 = odeprob.u0
# Passed for interface completeness; the hand-written `hires!` hard-codes its constants.
p_ = odeprob.p
prob = ODEProblem(hires!, u0, tspan, p_)
# Save every 0.1 on [0, 5], then every 0.5 from 5.3 (last requested time: 29.8).
solution = solve(prob, Rodas5P(); abstol = 1e-8, reltol = 1e-8,
    saveat = vcat(0:0.1:5, 5.3:0.5:30))
X = Array(solution)
# NOTE: this rebinds the global `t` (previously the symbolic independent variable
# from `@variables`) to the vector of save times; `predict`'s default `T = t` and
# the plotting callback read this vector.
t = solution.t
# Training data. No noise is added: `Xₙ` is the noise-free solution itself
# (an alias of `X`, not a copy).
Xₙ = X
labels = reshape(string.(states(odesys)), 1, :)
# %%-----------------------------------------
# Gaussian radial-basis activation exp(-x^2), element-wise
# (not referenced elsewhere in this script).
rbf(x) = @. exp(-x^2)
# Multilayer feed-forward network: 8 inputs -> two tanh hidden layers of width 20
# -> 8 outputs.
U = let width = 20
    Lux.Chain(
        Lux.Dense(8, width, tanh),
        Lux.Dense(width, width, tanh),
        Lux.Dense(width, 8),
    )
end
# Initial network parameters and layer state, drawn from the seeded `rng`.
p, st = Lux.setup(rng, U)
_st = st
# Define the hybrid model
"""
    ude_dynamics!(du, u, p, t, p_true)

Hybrid right-hand side: the known HIRES dynamics `rhs_f(u)` plus the
neural-network term `U(u, p, _st)`, written into `du`. `p` holds the network
parameters; `t` and `p_true` are currently unused. Returns `nothing`.
"""
function ude_dynamics!(du, u, p, t, p_true)
    û = U(u, p, _st)[1] # Network prediction
    # `rhs_f` must be *called* on the state: broadcasting the bare function value
    # (`rhs_f .+ û`) tries to add a Function to numbers and throws a MethodError.
    du .= rhs_f(u) .+ û
    return nothing
end
# Closure with the known parameter: fixes `p_true = p_` so the hybrid model has
# the standard in-place `f!(du, u, p, t)` signature expected by `ODEProblem`.
function nn_dynamics!(du, u, p, t)
    return ude_dynamics!(du, u, p, t, p_)
end
# Hybrid-model problem: first data column as initial state, data `tspan`, and the
# freshly initialized network parameters.
prob_nn = ODEProblem(nn_dynamics!, Xₙ[:, 1], tspan, p)
# %%-----------------------------------------
"""
    predict(θ, X = Xₙ[:, 1], T = t)

Solve the hybrid problem `prob_nn` with network parameters `θ` from initial state
`X`, integrating from `first(T)` to `last(T)` and saving at `T`; return the
solution as an array (states × saved times).
"""
function predict(θ, X = Xₙ[:, 1], T = t)
    problem = remake(prob_nn, u0 = X, tspan = (first(T), last(T)), p = θ)
    # Adjoint method used when the optimizer differentiates through this solve.
    sens = QuadratureAdjoint(autojacvec = ReverseDiffVJP(true))
    solution_nn = solve(problem, TRBDF2(); saveat = T,
        abstol = 1e-6, reltol = 1e-6, sensealg = sens)
    return Array(solution_nn)
end
"""
    loss(θ)

Mean squared error between the hybrid-model prediction and `Xₙ`, plus a penalty
of `0.2 * sum(min(x, 0)^2)` on negative predicted values. Returns
`(loss_value, prediction)` so the callback can plot the prediction.
"""
function loss(θ)
    X̂ = predict(θ)
    data_misfit = mean(abs2, Xₙ .- X̂)
    negativity = 0.2 * sum(abs2, min.(X̂, 0))
    return data_misfit + negativity, X̂
end
# Loss history, appended to by `callback`.
losses = Float64[]
gr()
# Optimization callback (p, l, pred; doplot = true, print_every = 1):
# records the loss `l` in `losses`; every `print_every`-th call (0 disables) prints
# it and, when `doplot` is true, plots state 7 of `pred` against the data.
# `p` is unused. Returning `false` lets the optimization continue.
callback = function (p, l, pred; doplot = true, print_every = 1)
    push!(losses, l)
    iter = length(losses)
    if print_every > 0 && iter % print_every == 0
        println("Current loss after $(iter) iterations: $(losses[end])")
        # plot current prediction against data
        if doplot
            plt = scatter(t, Xₙ[7, :]; markersize = 3, alpha = 0.7, label = "data")
            plot!(plt, t, pred[7, :]; label = "prediction")
            display(plt)
        end
    end
    return false
end
# Gradients via Zygote; the parameter argument of the objective is unused.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentVector{Float64}(p))
# Stage 1: Adam (`ADAM` is the deprecated/removed spelling in Optimisers.jl).
res1 = Optimization.solve(optprob, Adam(0.01), callback = callback, maxiters = 500)
# Stage 2: continue from the Adam result with L-BFGS.
optprob2 = Optimization.OptimizationProblem(optf, res1.u)
res2 = Optimization.solve(optprob2, Optim.LBFGS(), callback = callback, maxiters = 300)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.