# GitHub gist gasagna/abfe4f22d197154b1412a531cd0b4f04 (last active September 9, 2016 17:41).
#
# GitHub warns that this file may contain bidirectional/hidden Unicode text; the
# visible non-ASCII characters are Unicode identifiers such as `ẋ`, `x₀` and `Δt`.
# Open the file in an editor that reveals hidden Unicode characters to inspect it.
import Base: start, done, next, length
import Sundials: N_Vector,
                 NVector,
                 realtype,
                 DlsMat,
                 CVodeInit,
                 CVodeCreate,
                 CVodeSStolerances,
                 CVodeReInit,
                 CVodeSetUserData,
                 CVDense,
                 CV_SUCCESS,
                 CV_NEWTON,
                 CV_NORMAL,
                 CV_BDF,
                 CVodeFree,
                 CVODEMem,
                 CVode
# Holder for the user's right-hand-side function. An instance is registered with
# CVODE as user data and handed back to `fcallback`, which calls `f(t, x, ẋ)`.
type ODEProblem
    f::Function  # right-hand side; called as f(t, x, ẋ) and expected to fill ẋ in place
end
# C-callable right-hand side passed to CVODE through `cfunction`.
#
# CVODE supplies the current time `t`, the state `x`, the output vector `ẋ`, and
# the user data registered with `CVodeSetUserData` (the `ODEProblem`). The user
# function is expected to write the derivative into `ẋ` in place. This relies on
# `convert(Vector, ẋ)` exposing the N_Vector's storage rather than a copy —
# TODO confirm against the Sundials.jl version in use.
function fcallback(t::realtype, x::N_Vector, ẋ::N_Vector, prob::ODEProblem)
    state = convert(Vector, x)
    deriv = convert(Vector, ẋ)
    prob.f(t, state, deriv)
    # report a successful evaluation back to CVODE
    return CV_SUCCESS
end
# Throw an error naming the CVODE routine `fname` when `flag` is not CV_SUCCESS.
_check_cvode_flag(flag, fname) =
    flag == CV_SUCCESS || error("$fname failed with flag $flag")

# Julia wrapper around a CVODE solver object.
#
# Besides the raw solver pointer, this keeps a reference to every Julia object
# that CVODE refers to by raw pointer (the `ODEProblem` user data). This prevents
# the garbage collector from reclaiming it while the solver can still call back.
#
# NOTE(review): `mem` is never released. `CVodeFree` is imported by this file but
# not called anywhere; consider registering a finalizer once the exact
# Sundials.jl `CVodeFree` signature is confirmed.
type JuliaCVode
    mem::Ptr{CVODEMem}     # CVODE solver memory returned by CVodeCreate
    xnv::NVector           # nvector holding the current solution
    x₀::Vector{Float64}    # private copy of the initial condition, used to reset the state
    tout::Vector{Float64}  # one-element buffer receiving the time reached by CVode
    prob::ODEProblem       # CVODE user data; stored here so the GC keeps it alive

    # Arguments:
    #   prob       - right-hand side wrapper, passed to `fcallback` as user data
    #   x₀         - initial condition at t = 0 (any real element type; copied)
    #   rtol, atol - scalar relative/absolute tolerances for CVodeSStolerances
    # Throws an error if CVODE cannot be allocated or any setup call fails.
    function JuliaCVode(prob::ODEProblem, x₀::AbstractVector, rtol::Float64, atol::Float64)
        # setup memory: BDF + Newton iteration
        mem = CVodeCreate(CV_BDF, CV_NEWTON)
        mem == C_NULL && error("Failed to allocate CVODE solver object")

        # C-callable trampoline; CVODE hands `prob` back to it as user data
        fun = cfunction(fcallback,
                        Cint,
                        (realtype, N_Vector, N_Vector, Ref{ODEProblem}))

        # Own the initial condition. `collect(Float64, ...)` accepts integer input and
        # always returns a fresh array, so later mutation of the caller's
        # vector cannot alter the reset state.
        x0 = collect(Float64, x₀)
        xnv = NVector(copy(x0))

        _check_cvode_flag(CVodeSetUserData(mem, prob), "CVodeSetUserData")
        _check_cvode_flag(CVodeInit(mem, fun, 0.0, convert(N_Vector, copy(x0))), "CVodeInit")
        _check_cvode_flag(CVodeSStolerances(mem, rtol, atol), "CVodeSStolerances")
        _check_cvode_flag(CVDense(mem, length(x0)), "CVDense")

        return new(mem, xnv, x0, [0.0], prob)
    end
end
# ~~~ Custom Types ~~~ | |
# Iterator yielding the ODE solution at uniformly spaced times Δt, 2Δt, ...
# All iterations share the single integrator stored in `jcv`.
type Sampler
    jcv::JuliaCVode  # wrapped CVODE integrator that advances the state
    Δt::Float64      # spacing between consecutive samples
end
# External constructor: build a `Sampler` that integrates the ODE defined by `f`
# from `x₀` at t = 0 and yields the state every `Δt` time units.
#
# Arguments:
#   f  - right-hand side, called as f(t, x, ẋ) and writing the derivative into ẋ
#   x₀ - initial condition
#   Δt - sampling interval
# Keywords:
#   rtol, atol - relative/absolute tolerances for CVODE. Any `Real` is accepted
#                and converted to Float64, so e.g. `atol=0` works.
function sample(f::Function, x₀::AbstractVector, Δt::Real;
                rtol::Real=1e-6, atol::Real=1e-6)
    integrator = JuliaCVode(ODEProblem(f), x₀, Float64(rtol), Float64(atol))
    return Sampler(integrator, Δt)
end
# Iteration protocol | |
# A Sampler never terminates on its own. Declaring this through the iteration
# interface lets generic code (e.g. `zip`) treat the size as unbounded instead of
# relying on a huge fake length.
Base.iteratorsize(::Type{Sampler}) = Base.IsInfinite()
# Kept for backward compatibility with callers that query `length` directly.
length(s::Sampler) = typemax(Int64)
# Begin an iteration: reset the integrator to the stored initial condition at
# t = 0 and return the initial iteration state (the current time, 0.0).
#
# NOTE: the CVODE object is shared by every iteration over the same Sampler.
# Two simultaneous iterations (e.g. `zip(S, S)`) would therefore interfere.
function start(s::Sampler)
    s.jcv.tout[1] = 0.0
    s.jcv.xnv[:] = s.jcv.x₀
    flag = CVodeReInit(s.jcv.mem, s.jcv.tout[1], s.jcv.xnv)
    flag == CV_SUCCESS || error("CVodeReInit failed with flag $flag")
    return s.jcv.tout[1]
end
# Advance the integrator by one sampling interval.
#
# Returns `(x, tnext)`:
#   - `x` is the solution at `tnext`, copied out of the solver's nvector. The
#     conversion may expose the nvector's own storage, and that storage is
#     overwritten on every step, so without the copy all collected samples
#     could alias the same array.
#   - `tnext` is the next iteration state.
# Throws an error if CVode reports a failure (negative flag); without this check
# the iterator would silently yield a stale state.
function next(s::Sampler, tcurr)
    tnext = tcurr + s.Δt
    flag = CVode(s.jcv.mem, tnext, s.jcv.xnv, s.jcv.tout, CV_NORMAL)
    flag < 0 && error("CVode failed while integrating to t = $tnext (flag $flag)")
    x = copy(convert(Vector{Float64}, s.jcv.xnv))
    return (x, tnext)::Tuple{Vector{Float64}, Float64}
end
# The sampler is unbounded: iteration never stops on its own, so callers must
# limit it externally (for example by zipping it with a finite range).
function done(::Sampler, tcurr)
    return false
end
# ~~~ CODE ~~~ | |
# Right-hand side of the Lorenz system with σ = 10, ρ = 28, β = 8/3.
# Writes the time derivative of `x` into `ẋ` in place. `t` is unused (the system
# is autonomous) but is required by the f(t, x, ẋ) calling convention.
# Returns the value assigned to ẋ[3].
function lorenz(t::Real, x::AbstractVector, ẋ::AbstractVector)
    σ, ρ, β = 10, 28, 8/3
    x1, x2, x3 = x[1], x[2], x[3]
    ẋ[1] = σ * (x2 - x1)
    ẋ[2] = ρ * x1 - x2 - x1 * x3
    ẋ[3] = x1 * x2 - β * x3
end
# Sample the Lorenz system starting at (1, 1, 28), once per unit of time, with
# loose tolerances.
S = sample(lorenz, [1, 1, 28.0], 1.0; rtol=1e-3, atol=1e-3)

# The sampler never ends on its own; pairing it with 1:100 stops after 100 samples.
for (k, snapshot) in zip(1:100, S)
    println(k, snapshot)
end
# (End of the GitHub gist page: "Sign up for free to join this conversation on
# GitHub. Already have an account? Sign in to comment.")