Last active
May 27, 2022 12:31
-
-
Save visr/5c5f5812552b6e261c2648c87233e9bd to your computer and use it in GitHub Desktop.
Save and plot parameter history with ModelingToolkit.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Example from https://mtk.sciml.ai/dev/tutorials/ode_modeling/
# with a parameter τ that is updated through a callback.
# The parameter history is tracked, which allows us to create interpolating
# functions that show the correct history for parameters and for observed variables
# that depend on parameters.
# This works around limitations discussed here:
# https://discourse.julialang.org/t/update-parameters-in-modelingtoolkit-using-callbacks/63770
# It was posted here: https://discourse.julialang.org/t/save-and-plot-parameter-history-with-modelingtoolkit-jl/81778
# License: MIT
# Versions used: ModelingToolkit v8.12 and CairoMakie v0.8
using ModelingToolkit | |
using DifferentialEquations: solve | |
import DifferentialEquations as DE | |
using DiffEqCallbacks: PeriodicCallback | |
using Symbolics: getname | |
using SciMLBase | |
using CairoMakie | |
""" | |
ForwardFill(t, v) | |
Create a callable struct that will give a value from v on or after a given t. | |
There is a tolerance of 1e-4 for t to avoid narrowly missing the next timestep. | |
v = rand(21) | |
ff = ForwardFill(0:0.1:2, v) | |
ff(0.1) == v[2] | |
ff(0.1 - 1e-5) == v[2] | |
ff(0.1 - 1e-3) == v[1] | |
""" | |
struct ForwardFill{T, V} | |
t::T | |
v::V | |
function ForwardFill(t::T, v::V) where {T, V} | |
n = length(t) | |
if n != length(v) | |
error("ForwardFill vectors are not of equal length") | |
end | |
if !issorted(t) | |
error("ForwardFill t is not sorted") | |
end | |
new{T, V}(t, v) | |
end | |
end | |
"Interpolate into a forward filled timeseries at t" | |
function (ff::ForwardFill{T, V})(t)::eltype(V) where {T, V} | |
# Subtract a small amount to avoid e.g. t = 2.999999s not picking up the t = 3s value. | |
# This can occur due to floating point issues with the calculated t::Float64 | |
# The offset is larger than the eps of 1 My in seconds, and smaller than the periodic | |
# callback interval. | |
i = searchsortedlast(ff.t, t + 1e-4) | |
i == 0 && throw(DomainError(t, "Requesting t before start of series.")) | |
return ff.v[i] | |
end | |
"Interpolate and get the index j of the result, useful for V=Vector{Vector{Float64}}" | |
function (ff::ForwardFill{T, V})(t, j)::eltype(eltype(V)) where {T, V} | |
i = searchsortedlast(ff.t, t + 1e-4) | |
i == 0 && throw(DomainError(t, "Requesting t before start of series.")) | |
return ff.v[i][j] | |
end | |
"Update a parameter value" | |
function set_param!(integrator, sys, param, x::Real)::Real | |
p = integrator.p | |
param = getname(param)::Symbol | |
params = getname.(parameters(sys)) | |
i = findfirst(==(param), params) | |
i === nothing && error("Parameter not found: $param") | |
return p[i] = x | |
end | |
"Add an entry to the parameter history" | |
function save!(param_hist::ForwardFill, t::Float64, p::Vector{Float64}) | |
push!(param_hist.t, t) | |
push!(param_hist.v, copy(p)) | |
return param_hist | |
end | |
""" | |
interpolator(sys, integrator, param_hist::ForwardFill, sym)::Function | |
Return a time interpolating function for the given a Symbolic or Symbol. | |
""" | |
function interpolator(sys, integrator, param_hist::ForwardFill, sym)::Function | |
sol = integrator.sol | |
# convert to Symbol to allow sym::Union{Symbolic, Num} input | |
sym = getname(sym)::Symbol | |
i = findfirst(==(sym), getname.(states(sys))) | |
if i !== nothing | |
# use solution as normal | |
return t -> sol(t, idxs = i) | |
end | |
i = findfirst(==(sym), getname.(parameters(sys))) | |
if i !== nothing | |
# use param_hist | |
return t -> param_hist(t, i) | |
end | |
observed_terms = [obs.lhs for obs in observed(sys)] | |
observed_names = getname.(observed_terms) | |
i = findfirst(==(sym), observed_names) | |
if i !== nothing | |
# combine solution and param_hist | |
f = SciMLBase.getobserved(sol) # generated function | |
# the observed will be interpolated if the state it gets is interpolated | |
# and the parameters are current | |
term = observed_terms[i] | |
t -> f(term, sol(t), param_hist(t), t) | |
else | |
error("Not found in system: $sym") | |
end | |
end | |
# Independent variable t, the state x(t) and the auxiliary variable RHS(t).
@variables t
@variables x(t) RHS(t)
# Time constant; it is changed during the simulation by a periodic callback.
@parameters τ
# Differentiation with respect to t.
D = Differential(t)
# Two equations: RHS is defined algebraically and drives the derivative of x.
# After structural_simplify, x is a state and RHS is an observed variable.
@named sys = ODESystem([
    RHS ~ (1 - x) / τ,
    D(x) ~ RHS,
])
sim = structural_simplify(sys)
# A place to store the parameter values over time. The default solution object does not
# track these, and will only show the latest value. To be able to plot observed states that
# depend on parameters correctly, we need to save them over time. We can only save them
# after updating them, so the timesteps don't match the saved timestamps in the solution.
param_hist = ForwardFill(Float64[], Vector{Float64}[])
prob = ODEProblem(sim, [x => 0.0], (0.0, 10.0), [τ => 3.0])
# the τ values as they change over time
τs = ForwardFill([0.0, 4.0, 8.0], [3.0, 5.0, 2.0])

"""
    periodic_update!(integrator)

Callback affect: set τ to its scheduled value `τs(t)` for the current time and record
the updated parameter vector in `param_hist`.
"""
function periodic_update!(integrator)
    (; t, p) = integrator
    # Resolve the parameter index in `sim`, the system `prob` was built from and the one
    # the interpolators use, rather than the unsimplified `sys`. The two only give the
    # same index if their parameter order agrees (true here with the single τ).
    set_param!(integrator, sim, τ, τs(t))
    # p is the same (mutated) vector as integrator.p, so this saves the new τ
    save!(param_hist, t, p)
    return nothing
end

# Run periodic_update! at initialization and then every 4.0 time units, matching the
# breakpoints of τs.
cb = PeriodicCallback(periodic_update!, 4.0; initial_affect = true)
integrator = init(prob, DE.Rodas4(), callback = cb, save_on = true)
solve!(integrator)
sol = integrator.sol
# Interpolating functions of time for plotting with Makie: x is a state, RHS an
# observed variable and τ the parameter that changes during the simulation.
x_itp, RHS_itp, τ_itp = map((x, RHS, τ)) do var
    interpolator(sim, integrator, param_hist, var)
end
# Plot x and the observed RHS in the top axis and the time-varying parameter τ in the
# bottom axis, each as a function of time over the simulated interval.
begin
    timespan = 0.0 .. 10.0
    fig = Figure()
    ax1 = Axis(fig[1, 1])
    ax2 = Axis(fig[2, 1])
    lines!(ax1, timespan, x_itp, label = "x")
    lines!(ax1, timespan, RHS_itp, label = "RHS")
    axislegend(ax1)
    lines!(ax2, timespan, τ_itp, label = "τ")
    axislegend(ax2)
    # returning the figure displays it when run interactively
    fig
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment