Skip to content

Instantly share code, notes, and snippets.

@gasagna
Last active September 9, 2016 17:41
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save gasagna/abfe4f22d197154b1412a531cd0b4f04 to your computer and use it in GitHub Desktop.
Save gasagna/abfe4f22d197154b1412a531cd0b4f04 to your computer and use it in GitHub Desktop.
import Base: start, done, next, length
import Sundials: N_Vector,
NVector,
realtype,
DlsMat,
CVodeInit,
CVodeCreate,
CVodeSStolerances,
CVodeReInit,
CVodeSetUserData,
CVDense,
CV_SUCCESS,
CV_NEWTON,
CV_NORMAL,
CV_BDF,
CVodeFree,
CVODEMem,
CVode
"""
    ODEProblem(f)

Container for the user-supplied right-hand side `f`, called as `f(t, x, ẋ)` and
expected to fill `ẋ` in place. An instance is registered with CVODE as user data
and handed back to `fcallback`.
"""
type ODEProblem
    f::Function  # right-hand side, invoked as f(t, x, ẋ)
end
"""
    fcallback(t, x, ẋ, prob)

Right-hand side trampoline turned into a C function pointer and registered with
CVODE. Converts the `N_Vector`s `x` and `ẋ` to Julia `Vector`s and calls
`prob.f(t, x, ẋ)`. The return value of `prob.f` is ignored, and `CV_SUCCESS` is
always reported back to CVODE.

NOTE(review): this relies on `convert(Vector, ::N_Vector)` wrapping the CVODE
buffer without copying, so that writes to `ẋ` reach the solver. Confirm
against the Sundials.jl version in use.
"""
function fcallback(t::realtype, x::N_Vector, ẋ::N_Vector, prob::ODEProblem)
    state = convert(Vector, x)
    deriv = convert(Vector, ẋ)
    prob.f(t, state, deriv)
    return CV_SUCCESS
end
"""
    JuliaCVode(prob::ODEProblem, x₀::AbstractVector, rtol::Float64, atol::Float64)

Julia wrapper around a CVODE solver instance for `prob`. The solver is
configured with:
- BDF stepping (`CV_BDF`) and Newton iteration (`CV_NEWTON`),
- a dense linear solver (`CVDense`),
- scalar tolerances `rtol`/`atol`,
- initial time `0.0`.

Fields:
- `mem`:  CVODE memory block returned by `CVodeCreate`.
- `xnv`:  `NVector` holding the solution; `CVode` writes its output here.
- `x₀`:   private `Float64` copy of the initial condition, used to reset the solver.
- `tout`: one-element buffer reused to receive the time reached by `CVode`.
- `prob`: the problem registered with CVODE as user data. CVODE stores only a raw
          pointer to it, so it is kept here to stay reachable by the GC for as long
          as the solver may call `fcallback` with it.

Throws `ArgumentError` if `x₀` is empty. Throws an error if allocation fails or if
any CVODE setup call reports failure.

NOTE(review): `mem` is never released (`CVodeFree` is imported but unused),
including when a setup call fails. Consider a finalizer once the argument type
expected by the `CVodeFree` wrapper is confirmed.
"""
type JuliaCVode
    mem::Ptr{CVODEMem}
    xnv::NVector
    x₀::Vector{Float64}
    tout::Vector{Float64}
    prob::ODEProblem
    function JuliaCVode(prob::ODEProblem, x₀::AbstractVector,
                        rtol::Float64, atol::Float64)
        isempty(x₀) && throw(ArgumentError("initial condition x₀ must not be empty"))
        # Private Float64 copy: NVector requires a Vector{Float64}, and later changes
        # to the caller's vector must not alter the state `start` resets to.
        xinit = collect(Float64, x₀)

        mem = CVodeCreate(CV_BDF, CV_NEWTON)
        mem == C_NULL && error("Failed to allocate CVODE solver object")
        # CVODE setup calls return 0 (CV_SUCCESS / CVDLS_SUCCESS) on success.
        check(flag, name) = flag == CV_SUCCESS || error("$name failed with flag $flag")

        # C-callable right-hand side; `prob` comes back to it as CVODE user data.
        fun = cfunction(fcallback, Cint,
                        (realtype, N_Vector, N_Vector, Ref{ODEProblem}))
        xnv = NVector(copy(xinit))
        check(CVodeSetUserData(mem, prob), "CVodeSetUserData")
        # CVODE copies the initial condition into its own storage, so the rooted
        # `xnv` is passed rather than an unowned temporary N_Vector.
        check(CVodeInit(mem, fun, 0.0, xnv), "CVodeInit")
        check(CVodeSStolerances(mem, rtol, atol), "CVodeSStolerances")
        check(CVDense(mem, length(xinit)), "CVDense")
        new(mem, xnv, xinit, [0.0], prob)
    end
end
# ~~~ Custom Types ~~~
"""
    Sampler

Iterator over the state of the ODE integrated by `jcv`, sampled at intervals of
`Δt` time units.
"""
type Sampler
    jcv::JuliaCVode  # CVODE integrator wrapper (solver memory, state, initial condition)
    Δt::Float64      # sampling interval
end
"""
    sample(f, x₀, Δt; rtol=1e-6, atol=1e-6)

Return a `Sampler` over the ODE whose right-hand side `f(t, x, ẋ)` fills `ẋ` in
place. Integration starts from `x₀`, and the state is yielded every `Δt` time
units. `rtol` and `atol` are the relative and absolute tolerances handed to the
`JuliaCVode` integrator.
"""
function sample(f::Function, x₀::AbstractVector, Δt::Real;
                rtol::Float64=1e-6, atol::Float64=1e-6)
    problem = ODEProblem(f)
    integrator = JuliaCVode(problem, x₀, rtol, atol)
    return Sampler(integrator, Δt)
end
# Iteration protocol
# A Sampler never finishes (its `done` always returns false), so it advertises the
# largest representable length. `zip(1:n, s)` then takes its length from the
# finite partner.
# NOTE(review): on Julia >= 0.5, `Base.iteratorsize(::Type{Sampler}) = Base.IsInfinite()`
# would describe this iterator more accurately.
function length(::Sampler)
    return typemax(Int64)
end
"""
    start(s::Sampler)

Reset the integrator to its stored initial condition at time `0.0` and return
the initial iteration state, which is the current time (`0.0`). Throws an error
if `CVodeReInit` reports failure.
"""
function start(s::Sampler)
    jcv = s.jcv
    jcv.tout[1] = 0.0
    jcv.xnv[:] = jcv.x₀
    # Re-initialise CVODE so that every fresh iteration replays the same trajectory.
    flag = CVodeReInit(jcv.mem, jcv.tout[1], jcv.xnv)
    flag == CV_SUCCESS || error("CVodeReInit failed with flag $flag")
    return jcv.tout[1]
end
"""
    next(s::Sampler, tcurr)

Advance the integrator from `tcurr` to `tnext = tcurr + s.Δt` and return
`(x, tnext)`, where `x` is a fresh copy of the state at `tnext`. Throws an error
if `CVode` returns a negative (failure) flag.
"""
function next(s::Sampler, tcurr)
    jcv = s.jcv
    tnext = tcurr + s.Δt
    # CV_NORMAL: CVODE steps to (or just past) tnext and interpolates the state there.
    flag = CVode(jcv.mem, tnext, jcv.xnv, jcv.tout, CV_NORMAL)
    # Negative flags are solver failures. Non-negative ones (success, tstop/root
    # returns) leave a valid state in `xnv`.
    flag < 0 && error("CVode failed at t = $tnext with flag $flag")
    # `xnv` is overwritten on every call, and `convert` may return its backing
    # array without copying. Copy so that earlier samples are not clobbered.
    x = copy(convert(Vector{Float64}, jcv.xnv))
    return (x, tnext)::Tuple{Vector{Float64}, Float64}
end
# A Sampler is an endless stream: iteration never reports completion, whatever
# the state.
function done(::Sampler, state)
    return false
end
# ~~~ CODE ~~~
"""
    lorenz(t, x, ẋ)

Right-hand side of the Lorenz system with σ = 10, ρ = 28, β = 8/3. Writes the
time derivative of `x` into `ẋ` in place; `t` is not used. Returns the value
assigned to `ẋ[3]`.
"""
function lorenz(t::Real, x::AbstractVector, ẋ::AbstractVector)
    σ, ρ, β = 10, 28, 8/3
    ẋ[1] = σ * (x[2] - x[1])
    ẋ[2] = ρ * x[1] - x[2] - x[1]*x[3]
    ẋ[3] = -β * x[3] + x[1]*x[2]
end
# Sample the Lorenz system starting from (1, 1, 28) every 1.0 time units, with
# loose tolerances, and print the first 100 samples prefixed by their index.
S = sample(lorenz, [1.0, 1.0, 28.0], 1.0; rtol=1e-3, atol=1e-3)
for (k, xk) in zip(1:100, S)
    println(k, xk)
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment