Skip to content

Instantly share code, notes, and snippets.

@alanderos91
Created April 4, 2018 00:47
Show Gist options
  • Save alanderos91/3328b96281b35d7eb24d0a0fdee12da9 to your computer and use it in GitHub Desktop.
Save alanderos91/3328b96281b35d7eb24d0a0fdee12da9 to your computer and use it in GitHub Desktop.
SSAJumpAggregator Interface
# Aggregator state for the Direct method. It carries only the fields required by the
# `SSAJumpAggregator` interface; more could be added in principle (e.g. a dependency graph).
#
# Fields:
# - `next_jump`:      index (into `rates` / `affects!`) of the next jump to execute
# - `next_jump_time`: absolute time at which `next_jump` fires
# - `end_time`:       no integrator stop time is registered for jumps at or after this time
# - `cur_rates`:      buffer that `generate_jump!` fills with the normalized cumulative rates
# - `sum_rate`:       total rate computed by the most recent `generate_jump!`
# - `rates`:          tuple of rate functions, each called as `rate(u, p, t)`
# - `affects!`:       tuple of affect functions, each called as `affect!(integrator)`
# - `save_positions`: save flags forwarded to `DiscreteCallback`
# - `rng`:            random number generator used for all sampling
#
# `mutable struct` replaces the `type` keyword (removed in Julia 1.0). The type must stay
# mutable because `generate_jump!` reassigns `next_jump`, `next_jump_time` and `sum_rate`.
mutable struct DirectJumpAggregation{T,F1,F2,RNG} <: SSAJumpAggregator
    next_jump::Int
    next_jump_time::T
    end_time::T
    cur_rates::Vector{T}
    sum_rate::T
    rates::F1
    affects!::F2
    save_positions::Tuple{Bool,Bool}
    rng::RNG
end
# Callback condition: true exactly when the integrator has reached the scheduled jump time.
@inline (p::DirectJumpAggregation)(u, t, integrator) = p.next_jump_time == t # condition
# Callback affect!: apply the scheduled jump, draw the next one, and register it as an
# integrator stop time when it falls before `end_time`.
# `retrieve_jump` also hands back the scheduled jump time; Direct ignores it, but an
# algorithm like NRM would use it in some intermediate step.
function (p::DirectJumpAggregation)(integrator) # affect!
    _, jump_idx = retrieve_jump(p)
    @inbounds p.affects![jump_idx](integrator)
    generate_jump!(p, integrator.u, integrator.p, integrator.t)
    p.next_jump_time < p.end_time && add_tstop!(integrator, p.next_jump_time)
    return nothing
end
##### implementation details #####
# Sample the next jump for the Direct method. The steps are:
#   1. refresh the rates (stored as normalized cumulative sums),
#   2. pick which jump fires,
#   3. draw the exponential waiting time with the total rate.
# The index is drawn from `p.rng` before the waiting time; both draws consume the same RNG.
function generate_jump!(p::DirectJumpAggregation,u,params,t)
    total_rate = cur_rates_as_cumsum(u, params, t, p.rates, p.cur_rates)
    chosen = randidx_bisection(p.cur_rates, rand(p.rng))
    wait_time = randexp_ziggurat(p.rng, total_rate)
    p.next_jump = chosen
    p.next_jump_time = t + wait_time
    p.sum_rate = total_rate
    return nothing
end
# Initialization functor: draw the first jump when the callback is set up and, if it falls
# before `end_time`, register it as an integrator stop time.
function (p::DirectJumpAggregation)(dj, u, t, integrator) # initialize
    generate_jump!(p, u, integrator.p, t)
    p.next_jump_time < p.end_time || return nothing
    add_tstop!(integrator, p.next_jump_time)
    return nothing
end
# Build a `DirectJumpAggregation` from a collection of constant-rate jumps, each of which
# provides `.rate` and `.affect!`. The rate and affect functions are gathered into tuples.
# The aggregator starts with no jump scheduled (`next_jump = 0`,
# `next_jump_time = typemax(Float64)`); the first jump is drawn by the initialization functor.
@inline function aggregate(aggregator::Direct,u,p,t,end_time,constant_jumps,save_positions,rng)
    # The trailing comma makes `(x...,)` a tuple on every Julia version; the form without
    # the comma is rejected from Julia 0.7 on.
    rates = ((c.rate for c in constant_jumps)...,)
    affects! = ((c.affect! for c in constant_jumps)...,)
    # `Vector{Float64}(n)` (uninitialized) was removed in Julia 1.0. `zeros` works everywhere
    # and avoids holding garbage before the first `generate_jump!` overwrites the buffer.
    cur_rates = zeros(Float64, length(rates))
    sum_rate = zero(Float64)
    next_jump = 0
    next_jump_time = typemax(Float64)
    # The default constructor requires `next_jump_time`, `end_time`, `eltype(cur_rates)` and
    # `sum_rate` to share one type `T` (here Float64). Converting lets e.g. an integer
    # `end_time` work instead of raising a MethodError.
    return DirectJumpAggregation(next_jump, next_jump_time, Float64(end_time), cur_rates,
                                 sum_rate, rates, affects!, save_positions, rng)
end
"""
An aggregator interface for SSA-like algorithms.

### Required Fields
- `next_jump`: index of the next jump to execute
- `next_jump_time`: time at which `next_jump` fires
- `end_time`: final time; jumps scheduled at or after it are not registered as stop times
- `cur_rates`: buffer for the current jump rates
- `sum_rate`: total rate of all jumps
- `rates`: the rate functions, each called as `rate(u, p, t)`
- `affects!`: the affect functions, each called as `affect!(integrator)`
- `save_positions`: save flags passed on to `DiscreteCallback`
- `rng`: random number generator used for sampling

### Required Methods
- `(p::SSAJumpAggregator)(dj, u, t, integrator)`: an initialization functor
- `(p::SSAJumpAggregator)(u, t, integrator)`: the callback condition
- `(p::SSAJumpAggregator)(integrator)`: the callback `affect!`
- `aggregate(aggregator::AbstractAggregatorAlgorithm, u, p, t, end_time, constant_jumps, save_positions, rng)`:
  construct the aggregator
- `generate_jump!(p, u, params, t)`: sample the next jump and update the fields above

The condition and `affect!` functors must be defined on each concrete subtype: adding call
methods to an abstract type is not allowed (JuliaLang/julia#14919).
"""
abstract type SSAJumpAggregator <: AbstractJumpAggregator end
##### defaults #####
# Default accessor: return the scheduled jump time together with the index of that jump.
@inline function retrieve_jump(p::SSAJumpAggregator)
    scheduled_time = p.next_jump_time
    scheduled_idx = p.next_jump
    return (scheduled_time, scheduled_idx)
end
# forbidden; see: https://github.com/JuliaLang/julia/issues/14919
# @inline function (p::SSAJumpAggregator)(u,t,integrator) # condition
# p.next_jump_time==t
# end
# function (p::SSAJumpAggregator)(integrator) # affect!
# ttnj, i = retrieve_jump(p)
# @inbounds p.affects![i](integrator)
# generate_jump!(p,integrator.u,integrator.p,integrator.t)
# if p.next_jump_time < p.end_time
# add_tstop!(integrator,p.next_jump_time)
# end
# nothing
# end
# Wrap an aggregator as a DiscreteCallback. The aggregator itself serves as condition,
# affect! and initializer, and it supplies its own save_positions.
function DiscreteCallback(c::SSAJumpAggregator)
    return DiscreteCallback(c, c; initialize = c, save_positions = c.save_positions)
end
##### required methods #####
# Fallback for aggregators that do not define their own sampler: does nothing.
function generate_jump!(p::SSAJumpAggregator, u, params, t)
    return nothing
end
# (p::SSAJumpAggregator)(dj,u,t,integrator) = nothing # initialize
# Fallback constructor hook for aggregator algorithms: returns `nothing`.
# NOTE(review): unlike `aggregate(::Direct, ...)`, this fallback takes no trailing `rng`
# argument — confirm which arity callers use.
function aggregate(aggregator::AbstractAggregatorAlgorithm, u, p, t, end_time,
                   constant_jumps, save_positions)
    return nothing
end
##### helper functions for updating rates #####
# Store `f(u, p, t)` for each rate function `f` into consecutive slots of `cur_rates`,
# starting at position `idx`. One function is peeled off the argument list per call.
# Presumably recursion over splatted arguments is chosen over a loop so that each call
# specializes on the concrete rate-function type.
@inline function fill_cur_rates(u,p,t,cur_rates,idx,first_rate,other_rates...)
    @inbounds cur_rates[idx] = first_rate(u,p,t)
    next_slot = idx + 1
    return fill_cur_rates(u,p,t,cur_rates,next_slot,other_rates...)
end
# Base case: exactly one rate function left.
@inline function fill_cur_rates(u,p,t,cur_rates,idx,last_rate)
    @inbounds cur_rates[idx] = last_rate(u,p,t)
    return nothing
end
# Evaluate every rate function into `cur_rates`, then overwrite the buffer in place with the
# running sum of the rates divided by their total. When the total is positive, the last
# entry is therefore 1 up to rounding, and the vector can be bisected with a uniform draw.
# Returns the (unnormalized) total rate.
function cur_rates_as_cumsum(u,p,t,rates,cur_rates)
    @inbounds fill_cur_rates(u,p,t,cur_rates,1,rates...)
    total = sum(cur_rates)
    @fastmath scale = 1/total
    @inbounds begin
        cur_rates[1] = scale*cur_rates[1]
        for k in 2:length(cur_rates)
            cur_rates[k] = scale*cur_rates[k] + cur_rates[k-1]
        end
    end
    return total
end
##### helper functions for sampling jump times #####
# Exponential waiting times with rate `sum_rate` (mean `1/sum_rate`).
# `randexp_ziggurat` scales Julia's `randexp`. `randexp_inverse` uses inverse-transform
# sampling: `rand()` returns values in [0, 1) and can be exactly 0.0, for which
# `-log(rand())` is `Inf`. `-log1p(-U) = -log(1 - U)` has the same Exp(1) distribution but
# stays finite, because `1 - U` lies in (0, 1].
@inline randexp_ziggurat(sum_rate) = randexp() / sum_rate
@inline randexp_ziggurat(rng,sum_rate) = randexp(rng) / sum_rate
@inline randexp_inverse(sum_rate) = -log1p(-rand()) / sum_rate
@inline randexp_inverse(rng,sum_rate) = -log1p(-rand(rng)) / sum_rate
##### helper functions for sampling jump indices #####
# Choose a jump index from `cur_rates` (the normalized cumulative rates) using a uniform
# draw `rng_val` in [0, 1).
#
# Jump `i` owns the half-open interval [cur_rates[i-1], cur_rates[i]), so the answer is the
# first index whose cumulative value exceeds `rng_val`, i.e. `searchsortedlast(...) + 1`.
# Unlike `searchsortedfirst`, this never selects a zero-rate jump when `rng_val` lands
# exactly on a boundary (e.g. `rng_val == 0.0` with `cur_rates[1] == 0`).
#
# Rounding during normalization can leave `cur_rates[end]` slightly below 1, so `rng_val`
# may exceed every entry. In that case, fall back to the last jump with a nonzero rate (the
# first index attaining the final cumulative value) rather than returning an index past
# the end.
@inline function randidx_bisection(cur_rates, rng_val)
    idx = searchsortedlast(cur_rates, rng_val) + 1
    if idx > length(cur_rates) && !isempty(cur_rates)
        idx = searchsortedfirst(cur_rates, cur_rates[end])
    end
    return idx
end
##### helper functions for coupled sampling #####
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment