Skip to content

Instantly share code, notes, and snippets.

@axsk
Created March 27, 2016 00:23
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save axsk/7cbffdbc2c9b7dae1077 to your computer and use it in GitHub Desktop.
Save axsk/7cbffdbc2c9b7dae1077 to your computer and use it in GitHub Desktop.
cvodes-autodiff
using Sundials, ForwardDiff
import Sundials: realtype, N_Vector
"""
    FAndP(f, p)

Pair a right-hand-side function `f` with its parameter vector `p`, so both can
be handed to CVODES as one user-data object and recovered later with `unzip`.
"""
type FAndP
    f::Function
    p::Vector{Float64}
end
"""
    unzip(fp::Ptr{Void})

Convert the raw user-data pointer received by the solver callbacks back into
the `FAndP` object it refers to, and return its contents as the tuple `(f, p)`.
"""
function unzip(fp::Ptr{Void})
    bundle = unsafe_pointer_to_objref(fp)::FAndP
    return bundle.f, bundle.p
end
"""
    cvodesfun(t, y, dy, fp)

Right-hand-side callback registered with `CVodeInit`.

Wrap the solver vectors `y` and `dy` as Julia arrays, unpack `(f, p)` from the
user-data pointer `fp`, and evaluate `f(t, y, p, dy)`. This relies on
`Sundials.asarray` wrapping `dy` without copying, so that `f`'s writes reach
the solver; confirm against the Sundials.jl version in use.

Always return `Int32(0)` (success) to the solver.
"""
function cvodesfun(t, y, dy, fp)
    state = Sundials.asarray(y)
    deriv = Sundials.asarray(dy)
    rhs, params = unzip(fp)
    rhs(t, state, params, deriv)
    return Int32(0)
end
"""
    differentiator(f, ny, np)

Build a Jacobian evaluator for an in-place right-hand side `f(t, y, p, dy)`
with `ny` states and `np` parameters.

Return a function `jac!(y0, p, t=0.0)` that:
- evaluates, with ForwardDiff, the `ny × (ny+np)` Jacobian of `dy` with respect
  to the stacked vector `[y; p]` at `(t, y0, p)`;
- stores it in a buffer that is reused across calls;
- returns that buffer. Columns `1:ny` hold ∂f/∂y and columns `ny+1:ny+np` hold
  ∂f/∂p.

Copy the result if it must survive the next call.
"""
function differentiator(f, ny, np)
    J = Matrix{Float64}(ny, ny + np)
    tcur = Ref(0.0)   # time at which the next Jacobian is evaluated

    # Non-mutating view of f on x = [y; p]. The output buffer is allocated with
    # `similar` so its element type follows x (ForwardDiff dual numbers).
    # This avoids shared untyped `Vector{Any}` buffers.
    function merged(x)
        dy = similar(x, ny)
        f(tcur[], x[1:ny], x[ny+(1:np)], dy)
        return dy
    end

    # With mutates=true ForwardDiff returns j!(output, x), which fills `output`.
    j! = ForwardDiff.jacobian(merged, mutates=true)

    function jac!(y0, p, t=0.0)
        tcur[] = t
        j!(J, vcat(y0, p))
        return J
    end
    return jac!
end
"""
    sensrhsfn(ns, t, y, ydot, yS, ySdot, user_data, tmp1, tmp2)

Forward-sensitivity right-hand side for CVODES (`CVSensRhsFn`), computed with
automatic differentiation instead of difference quotients.

For every parameter index `i` in `1:ns` it writes

    ySdot[i] = (∂f/∂y) * yS[i] + ∂f/∂p[i]

into the solver-owned vectors. The Jacobian of `f(t, y, p, dy)` with respect
to `[y; p]` is evaluated at the current time `t` and state `y`, and `f` and `p`
are recovered from `user_data` via `unzip`.

On the C side `yS` and `ySdot` are `N_Vector*` arrays of length `ns`. They are
accepted here as any `Ptr`, so this method also matches a `cfunction`
declaration that lists them as `N_Vector`. The two are ABI-identical plain
pointers.

Return `Int32(0)` (success) to the solver.
"""
function sensrhsfn(ns::Int32, t::realtype, y::N_Vector, ydot::N_Vector, yS::Ptr, ySdot::Ptr, user_data::Ptr{Void}, tmp1::N_Vector, tmp2::N_Vector)
    f, p = unzip(user_data)
    np = Int(ns)
    yv = Sundials.asarray(y)
    ny = length(yv)

    # View the raw pointers as arrays of `np` N_Vector handles. The handles are
    # owned by CVODES: results must be written into their data, never by
    # replacing the handles themselves.
    yS_vecs = pointer_to_array(convert(Ptr{N_Vector}, yS), np)
    ySdot_vecs = pointer_to_array(convert(Ptr{N_Vector}, ySdot), np)

    # Non-mutating wrapper around f for ForwardDiff, using x = [y; p] and the
    # actual solver time t. `similar` lets dy hold dual numbers.
    function rhs(x)
        dy = similar(x, ny)
        f(t, x[1:ny], x[ny+(1:np)], dy)
        return dy
    end
    J = ForwardDiff.jacobian(rhs, vcat(yv, p))
    Jy = J[:, 1:ny]

    for i in 1:np
        out = Sundials.asarray(ySdot_vecs[i])
        out[:] = Jy * Sundials.asarray(yS_vecs[i]) + J[:, ny+i]
    end
    return Int32(0)
end
"""
    cvodes(f, y0, p, ts; reltol=1e-8, abstol=1e-6, autodiff=true)

Integrate `dy/dt = f(t, y, p)` with CVODES (BDF, Newton iteration, dense
linear solver) and compute forward sensitivities of the solution with respect
to every entry of `p`.

`f` must have the in-place signature `f(t, y, p, dy)` and write the derivative
into `dy`.

With `autodiff=true` the sensitivity right-hand side is `sensrhsfn`.
Otherwise CVODES uses its built-in centered difference quotients.

Return `(solution, sens)`:
- `solution[k, :]` is the state at `ts[k]`;
- `sens[k, :, i]` is its derivative with respect to `p[i]`;
- row 1 holds `y0` and zero sensitivities.

Throw an `ArgumentError` for empty `ts`, and an error if CVODES cannot be
allocated or a `CVode` step reports failure.
"""
function cvodes(f::Function, y0::Vector{Float64}, p::Vector{Float64}, ts::Vector{Float64}; reltol=1e-8, abstol=1e-6, autodiff=true)
    isempty(ts) && throw(ArgumentError("ts must contain at least the initial time"))
    ny = length(y0)
    np = length(p)

    # Placeholders for the solution and sensitivities. Row 1 of `sens` stays
    # zero because y0 does not depend on p.
    solution = zeros(length(ts), ny)
    solution[1, :] = y0
    sens = zeros(length(ts), ny, np)

    cvode_mem = Sundials.CVodeCreate(Sundials.CV_BDF, Sundials.CV_NEWTON)
    cvode_mem == C_NULL && error("Failed to allocate CVODE solver object")

    # CVODES keeps only raw pointers to the user data and to the sensitivity
    # vectors. Hold the Julia objects (and the arrays backing the vectors) in
    # locals so the GC cannot collect them while the solver is running.
    userdata = FAndP(f, p)
    yS_data = [zeros(Float64, ny) for i in 1:np]
    yS_vecs = [Sundials.nvector(v) for v in yS_data]
    yS = pointer(yS_vecs)

    try
        ### CVode settings ###
        Sundials.CVodeInit(cvode_mem, cvodesfun, ts[1], y0)
        Sundials.CVodeSetUserData(cvode_mem, userdata)
        Sundials.CVodeSStolerances(cvode_mem, reltol, abstol)
        Sundials.CVDense(cvode_mem, ny)

        ### Sensitivity settings ###
        if autodiff
            # CVODES needs the C function pointer, not the Julia function.
            # yS/ySdot are `N_Vector*` in C; declaring them as `N_Vector` here
            # is ABI-identical (both are plain pointers).
            sensrhsfnptr = cfunction(sensrhsfn, Int32, (Int32, realtype, N_Vector, N_Vector, N_Vector, N_Vector, Ptr{Void}, N_Vector, N_Vector))
            Sundials.CVodeSensInit(cvode_mem, np, Sundials.CV_SIMULTANEOUS, sensrhsfnptr, yS)
        else
            # NULL right-hand side: CVODES falls back to difference quotients.
            Sundials.CVodeSensInit(cvode_mem, np, Sundials.CV_SIMULTANEOUS, Ptr{Void}(0), yS)
            Sundials.CVodeSetSensDQMethod(cvode_mem, Sundials.CV_CENTERED, 0.0)
        end
        Sundials.CVodeSetSensParams(cvode_mem, p, p, Ptr{Int32}(0))
        Sundials.CVodeSetSensErrCon(cvode_mem, 0)
        Sundials.CVodeSensEEtolerances(cvode_mem)

        tout = [0.0]      # output time actually reached by the solver
        yout = copy(y0)   # CVode writes the state here
        for k in 2:length(ts)
            flag = Sundials.CVode(cvode_mem, ts[k], yout, tout, Sundials.CV_NORMAL)
            flag < 0 && error("CVode failed with flag $flag at t = $(ts[k])")
            Sundials.CVodeGetSens(cvode_mem, tout, yS)
            solution[k, :] = yout
            for i in 1:np
                sens[k, :, i] = Sundials.asarray(yS_vecs[i])
            end
        end
    finally
        # Release the solver (including its sensitivity memory) on every exit path.
        Sundials.CVodeFree([cvode_mem])
    end
    return (solution, sens)
end
# Example right-hand side for the demo below: dy/dt = (p[1], p[2]),
# independent of t and y (constant velocity in each component).
function f(t, y, p, dy)
    for k in 1:2
        dy[k] = p[k]
    end
    # Same value the original's trailing `dy[2] = p[2]` evaluated to.
    return p[2]
end
# Demo: integrate the example system from y0 = (0.5, 0) with p = (1, 2) over
# 10 output times in [0, 1], using AD-based sensitivities.
cvodes(f, [0.5, 0.0], [1.0, 2.0], collect(linspace(0, 1, 10)); autodiff=true)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment