Skip to content

Instantly share code, notes, and snippets.

@visr
Last active May 27, 2022 12:31
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save visr/5c5f5812552b6e261c2648c87233e9bd to your computer and use it in GitHub Desktop.
Save visr/5c5f5812552b6e261c2648c87233e9bd to your computer and use it in GitHub Desktop.
Save and plot parameter history with ModelingToolkit.jl
# Example from https://mtk.sciml.ai/dev/tutorials/ode_modeling/
# with a parameter τ that is updated through a callback.
# The parameter history is tracked, which allows us to create interpolating
# functions that show the correct history for parameters and for observed variables
# that depend on parameters.
# This works around limitations discussed here:
# https://discourse.julialang.org/t/update-parameters-in-modelingtoolkit-using-callbacks/63770
# And was posted here: https://discourse.julialang.org/t/save-and-plot-parameter-history-with-modelingtoolkit-jl/81778
# License is MIT
# Versions are ModelingToolkit v8.12 and CairoMakie v0.8
using ModelingToolkit
using DifferentialEquations: solve
import DifferentialEquations as DE
using DiffEqCallbacks: PeriodicCallback
using Symbolics: getname
using SciMLBase
using CairoMakie
"""
    ForwardFill(t, v)

Callable step ("forward fill") interpolant over the sample times `t` and the matching
values `v`. Calling it at a time returns the value of `v` belonging to the latest
sample time on or before the requested time.

When called, the query time is shifted forward by a tolerance of `1e-4`. This means a
query landing just short of a sample time (e.g. `2.9999999` instead of `3.0`) still
picks up that sample.

Throws an `ErrorException` if `t` and `v` differ in length, or if `t` is not sorted.

# Examples
```julia
v = rand(21)
ff = ForwardFill(0:0.1:2, v)
ff(0.1) == v[2]
ff(0.1 - 1e-5) == v[2]
ff(0.1 - 1e-3) == v[1]
```
"""
struct ForwardFill{T, V}
    t::T
    v::V
    function ForwardFill(t::T, v::V) where {T, V}
        # Validate lengths first, then ordering; lookups rely on `t` being sorted.
        length(t) == length(v) || error("ForwardFill vectors are not of equal length")
        issorted(t) || error("ForwardFill t is not sorted")
        return new{T, V}(t, v)
    end
end
"""
    (ff::ForwardFill)(t)

Return the forward-filled value at time `t`. This is the entry of `ff.v` belonging to
the last sample time in `ff.t` that is `<= t + 1e-4`, converted to `eltype(V)`.

Throws a `DomainError` when `t` lies before the first sample time (beyond the
tolerance). This includes the case where the series is empty.
"""
function (ff::ForwardFill{T, V})(t)::eltype(V) where {T, V}
    # Look slightly ahead of `t` so that a computed Float64 time such as 2.999999
    # (floating point round-off) still selects the sample stored at 3.0. The shift
    # must stay smaller than the spacing between samples, e.g. the periodic callback
    # interval.
    # NOTE(review): an earlier comment claimed 1e-4 exceeds the eps of 1 My expressed
    # in seconds, but eps(3.156e13) ≈ 3.9e-3 > 1e-4 — confirm the intended time range.
    shifted = t + 1e-4
    idx = searchsortedlast(ff.t, shifted)
    if idx == 0
        throw(DomainError(t, "Requesting t before start of series."))
    end
    return ff.v[idx]
end
"""
    (ff::ForwardFill)(t, j)

Forward-fill lookup at time `t`, followed by indexing the selected entry with `j`.
The result is `ff.v[i][j]`, where `i` is the last sample with time `<= t + 1e-4`,
converted to `eltype(eltype(V))`. This is useful when `ff.v` is a vector of vectors
(for example a saved parameter history) and a single component is wanted.

Throws a `DomainError` when `t` lies before the first sample time (beyond the
tolerance). This includes the case where the series is empty.
"""
function (ff::ForwardFill{T, V})(t, j)::eltype(eltype(V)) where {T, V}
    # The 1e-4 look-ahead absorbs floating point round-off in `t`.
    idx = searchsortedlast(ff.t, t + 1e-4)
    idx > 0 || throw(DomainError(t, "Requesting t before start of series."))
    sample = ff.v[idx]
    return sample[j]
end
"""
    set_param!(integrator, sys, param, x::Real)

Overwrite the value of parameter `param` in `integrator.p` with `x`, and return `x`.

`param` may be anything `getname` accepts, such as the symbolic parameter or its
`Symbol`. Its position is looked up in `parameters(sys)`, and that position is used
directly to index `integrator.p`. Therefore `sys` must list its parameters in the same
order as the problem the integrator was built from.

Throws an `ErrorException` if `sys` has no parameter with that name.
"""
function set_param!(integrator, sys, param, x::Real)::Real
    name = getname(param)::Symbol
    param_names = getname.(parameters(sys))
    idx = findfirst(isequal(name), param_names)
    isnothing(idx) && error("Parameter not found: $name")
    integrator.p[idx] = x
    return x
end
"""
    save!(param_hist::ForwardFill, t, p)

Append a snapshot of the parameter vector `p` at time `t` to `param_hist`, and return
`param_hist`.

`p` is copied, so later in-place changes to it (for example `set_param!` writing into
`integrator.p`) do not alter the stored history.

`t` must not be earlier than the last saved time. `ForwardFill` lookups rely on
`param_hist.t` being sorted; the constructor checks this, but a plain `push!` would
not. Saving again at the same time is allowed, and lookups then return the most
recently saved snapshot.

Throws an `ArgumentError` if `t` is earlier than the last saved time.
"""
function save!(param_hist::ForwardFill, t::Real, p::AbstractVector)
    times, values = param_hist.t, param_hist.v
    if !isempty(times) && t < last(times)
        throw(ArgumentError(
            "cannot save parameters at t = $t, before the last saved time $(last(times))"))
    end
    # Convert both entries before mutating anything. A failed conversion then cannot
    # leave `times` and `values` with different lengths.
    new_t = convert(eltype(times), t)
    new_p = convert(eltype(values), copy(p))
    push!(times, new_t)
    push!(values, new_p)
    return param_hist
end
"""
    interpolator(sys, integrator, param_hist::ForwardFill, sym)::Function

Return a function `t -> value` that evaluates `sym` at time `t`, using the solution
held by `integrator`.

`sym` may be a `Symbol` or a symbolic variable/parameter (anything `getname` accepts).
It is resolved by name, in this order:

1. A state of `sys`: the solution is interpolated via `sol(t, idxs = i)`.
2. A parameter of `sys`: the value recorded in `param_hist` at `t` is returned.
3. An observed variable of `sys`: the generated observed function is evaluated with
   the interpolated state `sol(t)` and the parameter vector recorded in `param_hist`
   at `t`. Parameter changes over time are therefore reflected.

Throws an `ErrorException` if `sym` matches none of these.
"""
function interpolator(sys, integrator, param_hist::ForwardFill, sym)::Function
    sol = integrator.sol
    # Normalise to a Symbol so that both symbolic (`Num`) and `Symbol` inputs work.
    name = getname(sym)::Symbol

    state_idx = findfirst(==(name), getname.(states(sys)))
    isnothing(state_idx) || return t -> sol(t, idxs = state_idx)

    param_idx = findfirst(==(name), getname.(parameters(sys)))
    isnothing(param_idx) || return t -> param_hist(t, param_idx)

    lhs_terms = [eq.lhs for eq in observed(sys)]
    obs_idx = findfirst(==(name), getname.(lhs_terms))
    isnothing(obs_idx) && error("Not found in system: $name")

    # The generated observed function yields interpolated values when given the
    # interpolated state, together with the parameters that were current at `t`.
    obs_fn = SciMLBase.getobserved(sol)
    term = lhs_terms[obs_idx]
    return t -> obs_fn(term, sol(t), param_hist(t), t)
end
@variables t x(t) RHS(t)  # independent and dependent variables
@parameters τ  # parameters
D = Differential(t)  # define an operator for the differentiation w.r.t. time

# A single ODE for x, with its right-hand side written as the separate variable RHS
# (equations are written with `~`). After `structural_simplify`, RHS is an observed
# variable of `sim`.
@named sys = ODESystem([RHS ~ (1 - x) / τ,
                        D(x) ~ RHS])
sim = structural_simplify(sys)

# A place to store the parameter values over time. The default solution object does not
# track these, and will only show the latest value. To be able to plot observed states
# that depend on parameters correctly, we need to save them over time. We can only save
# them after updating them, so the timesteps don't match the saved timestamps in the
# solution.
param_hist = ForwardFill(Float64[], Vector{Float64}[])
prob = ODEProblem(sim, [x => 0.0], (0.0, 10.0), [τ => 3.0])

# The τ values as they change over time.
τs = ForwardFill([0.0, 4.0, 8.0], [3.0, 5.0, 2.0])

function periodic_update!(integrator)
    (; t, p) = integrator
    # Look up τ in `sim`, the system `prob` was built from, so that the index matches
    # the layout of `integrator.p`. Using the unsimplified `sys` here would assume it
    # shares `sim`'s parameter order.
    set_param!(integrator, sim, τ, τs(t))
    save!(param_hist, t, p)
    return nothing
end

# Fires every 4.0 time units; with `initial_affect = true` presumably also at the
# start, i.e. t = 0, 4, 8, matching the breakpoints of τs.
cb = PeriodicCallback(periodic_update!, 4.0; initial_affect = true)
integrator = init(prob, DE.Rodas4(), callback = cb, save_on = true)
solve!(integrator)
sol = integrator.sol

# Get interpolator functions and use them to plot with Makie.
x_itp = interpolator(sim, integrator, param_hist, x)  # state
RHS_itp = interpolator(sim, integrator, param_hist, RHS)  # observed
τ_itp = interpolator(sim, integrator, param_hist, τ)  # (dynamic) parameter

begin
    timespan = 0.0 .. 10.0
    fig = Figure()
    ax1 = Axis(fig[1, 1])
    ax2 = Axis(fig[2, 1])
    lines!(ax1, timespan, x_itp, label = "x")
    lines!(ax1, timespan, RHS_itp, label = "RHS")
    axislegend(ax1)
    lines!(ax2, timespan, τ_itp, label = "τ")
    axislegend(ax2)
    fig
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment