Created
April 4, 2018 00:47
-
-
Save alanderos91/3328b96281b35d7eb24d0a0fdee12da9 to your computer and use it in GitHub Desktop.
SSAJumpAggregator Interface
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
    DirectJumpAggregation{T,F1,F2,RNG} <: SSAJumpAggregator

Aggregator state for the direct (Gillespie) method. It holds only the fields required
by the `SSAJumpAggregator` interface; more could be added later (e.g. a dependency graph).

# Fields
- `next_jump::Int`: index into `rates`/`affects!` of the next jump to fire.
- `next_jump_time::T`: absolute time of the next jump (current time + sampled waiting time).
- `end_time::T`: jumps sampled at or after this time are not registered as tstops
  (presumably the end of the solve's time span — confirm with callers).
- `cur_rates::Vector{T}`: scratch buffer; after each sampling step it holds the
  normalized cumulative rates used to pick the next jump.
- `sum_rate::T`: total rate computed in the last sampling step.
- `rates::F1`: tuple of rate functions, each called as `rate(u, p, t)`.
- `affects!::F2`: tuple of affect functions, each called as `affect!(integrator)`,
  aligned index-by-index with `rates`.
- `save_positions::Tuple{Bool,Bool}`: forwarded to the `DiscreteCallback`.
- `rng::RNG`: random number generator used for all sampling.

The struct is mutable because sampling overwrites `next_jump`, `next_jump_time` and
`sum_rate` in place.
"""
mutable struct DirectJumpAggregation{T,F1,F2,RNG} <: SSAJumpAggregator
    next_jump::Int
    next_jump_time::T
    end_time::T
    cur_rates::Vector{T}
    sum_rate::T
    rates::F1
    affects!::F2
    save_positions::Tuple{Bool,Bool}
    rng::RNG
end
# Discrete-callback condition: true exactly when the integrator has reached the
# pre-sampled time of the next jump.
@inline (p::DirectJumpAggregation)(u,t,integrator) = t == p.next_jump_time
# Discrete-callback affect!: fire the pre-selected jump, sample the next one, and
# register its time as a tstop when it falls before `end_time`.
# The retrieved jump time is unused here, but an algorithm like NRM would need it
# in some intermediate step.
function (p::DirectJumpAggregation)(integrator) # affect!
    jump_time, jump_idx = retrieve_jump(p)
    @inbounds p.affects![jump_idx](integrator)
    generate_jump!(p, integrator.u, integrator.p, integrator.t)
    p.next_jump_time < p.end_time && add_tstop!(integrator, p.next_jump_time)
    return nothing
end
##### implementation details ##### | |
# Sample the next jump for the direct method: refresh the normalized cumulative rate
# table, choose the jump index by bisection on a uniform draw, then draw the
# exponential waiting time. Results are stored in `p`; nothing is returned.
function generate_jump!(p::DirectJumpAggregation,u,params,t)
    total = cur_rates_as_cumsum(u,params,t,p.rates,p.cur_rates)
    idx = randidx_bisection(p.cur_rates, rand(p.rng))   # must be drawn before the waiting time
    dt = randexp_ziggurat(p.rng,total)
    p.sum_rate, p.next_jump, p.next_jump_time = total, idx, t + dt
    return nothing
end
# Initializer (passed to `DiscreteCallback` as `initialize=`): sample the first jump
# and register its time as a tstop when it falls before `end_time`.
function (p::DirectJumpAggregation)(dj,u,t,integrator) # initialize
    generate_jump!(p, u, integrator.p, t)
    p.next_jump_time < p.end_time || return nothing
    add_tstop!(integrator, p.next_jump_time)
    return nothing
end
# Build a `DirectJumpAggregation` for the given constant-rate jumps.
#
# Rate and affect functions are collected into tuples; the rate tuple is later
# splatted into the recursive `fill_cur_rates` helper. The trailing comma in
# `(gen...,)` is required: the bare `(gen...)` form is a syntax error from Julia 1.0 on.
# The scratch buffer is zero-filled so it never exposes uninitialized memory.
# `end_time` is converted to Float64 so that an integer (or other Real) end time does
# not clash with the Float64 type parameter shared by the time and rate fields.
# No jump is sampled here; that happens in the initializer.
@inline function aggregate(aggregator::Direct,u,p,t,end_time,constant_jumps,save_positions,rng)
    rates = ((c.rate for c in constant_jumps)...,)
    affects! = ((c.affect! for c in constant_jumps)...,)
    cur_rates = zeros(Float64, length(rates))
    sum_rate = zero(Float64)
    next_jump = 0
    next_jump_time = typemax(Float64)
    DirectJumpAggregation(next_jump,next_jump_time,convert(Float64,end_time),cur_rates,
                          sum_rate,rates,affects!,save_positions,rng)
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
    SSAJumpAggregator <: AbstractJumpAggregator

An aggregator interface for SSA-like algorithms.

### Required Fields
- `next_jump`
- `next_jump_time`
- `end_time`
- `cur_rates`
- `sum_rate`
- `rates`
- `affects!`
- `save_positions`
- `rng`

### Required Methods
- `(p::SSAJumpAggregator)(dj,u,t,integrator)`: an initialization functor
- `(p::SSAJumpAggregator)(u,t,integrator)`: the callback condition functor
- `(p::SSAJumpAggregator)(integrator)`: the callback affect! functor
- `aggregate(aggregator::AbstractAggregatorAlgorithm,u,p,t,end_time,constant_jumps,save_positions,rng)`:
  constructs the aggregator
- `generate_jump!(p,u,params,t)`: samples the next jump and stores it in `p`
  (`next_jump`, `next_jump_time`, and typically `sum_rate`)

The condition and affect! functors cannot be defined once on the abstract type
(see https://github.com/JuliaLang/julia/issues/14919), so every concrete subtype
must define them itself.
"""
abstract type SSAJumpAggregator <: AbstractJumpAggregator end
##### defaults ##### | |
# Default accessor for any SSA aggregator: returns `(next_jump_time, next_jump)`.
@inline function retrieve_jump(p::SSAJumpAggregator)
    return (p.next_jump_time, p.next_jump)
end
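# Methods cannot be added to an abstract type (see https://github.com/JuliaLang/julia/issues/14919), so every concrete aggregator must define these two functors itself; the intended generic versions are kept below for reference.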
# @inline function (p::SSAJumpAggregator)(u,t,integrator) # condition | |
# p.next_jump_time==t | |
# end | |
# function (p::SSAJumpAggregator)(integrator) # affect! | |
# ttnj, i = retrieve_jump(p) | |
# @inbounds p.affects![i](integrator) | |
# generate_jump!(p,integrator.u,integrator.p,integrator.t) | |
# if p.next_jump_time < p.end_time | |
# add_tstop!(integrator,p.next_jump_time) | |
# end | |
# nothing | |
# end | |
# Wrap an SSA aggregator as a DiscreteCallback. The aggregator object itself is
# passed as the condition, the affect!, and the initializer, and it supplies the
# callback's save_positions.
function DiscreteCallback(c::SSAJumpAggregator)
    DiscreteCallback(c, c; initialize = c, save_positions = c.save_positions)
end
##### required methods ##### | |
# Fallback sampler: does nothing and returns `nothing`.
function generate_jump!(p::SSAJumpAggregator,u,params,t)
    return nothing
end
# Concrete aggregators must also supply the initializer functor
# `(p::SSAJumpAggregator)(dj,u,t,integrator)`; it has no generic fallback.
# Fallback constructor for algorithms without a specialized `aggregate`: returns `nothing`.
function aggregate(aggregator::AbstractAggregatorAlgorithm,u,p,t,end_time,constant_jumps,save_positions)
    return nothing
end
##### helper functions for updating rates ##### | |
# fill_cur_rates(u, p, t, cur_rates, idx, rates...)
# Evaluate each rate function at `(u, p, t)` and write the results to
# `cur_rates[idx]`, `cur_rates[idx+1]`, .... The tuple of rate functions is peeled off
# one element per call rather than looped over, so each call sees one concrete
# rate-function type. Bounds are NOT checked; the caller must size `cur_rates`.
@inline function fill_cur_rates(u,p,t,cur_rates,idx,rate,rates...)
    @inbounds cur_rates[idx] = rate(u,p,t)
    fill_cur_rates(u,p,t,cur_rates,idx + 1,rates...)
end
@inline function fill_cur_rates(u,p,t,cur_rates,idx,rate)
    @inbounds cur_rates[idx] = rate(u,p,t)
    nothing
end
# Base case for an empty tuple of rate functions: nothing to evaluate.
@inline fill_cur_rates(u,p,t,cur_rates,idx) = nothing

"""
    cur_rates_as_cumsum(u, p, t, rates, cur_rates)

Evaluate every rate function in `rates` at `(u, p, t)`. Overwrite `cur_rates` with the
cumulative sum of those rates divided by their total, and return the total rate.

The cumulative sum is accumulated first and then divided by its own last entry. For a
positive finite total, the final entry of `cur_rates` is therefore exactly `1.0`. Scaling
by a precomputed reciprocal of a separately computed `sum` can leave it slightly below 1.
A uniform draw could then fall past every entry and yield an out-of-range jump index.

If every rate is zero, the entries become `NaN` (0/0) and the returned total is zero.
With no rate functions, returns zero and leaves the (empty) buffer untouched.

Throws `DimensionMismatch` if `length(cur_rates) != length(rates)`.
"""
function cur_rates_as_cumsum(u,p,t,rates,cur_rates)
    n = length(cur_rates)
    n == length(rates) || throw(DimensionMismatch(
        "cur_rates has length $n but there are $(length(rates)) rate functions"))
    n == 0 && return zero(eltype(cur_rates))
    fill_cur_rates(u,p,t,cur_rates,1,rates...)
    @inbounds for i in 2:n          # in-place cumulative sum
        cur_rates[i] += cur_rates[i-1]
    end
    sum_rate = cur_rates[n]
    @inbounds for i in 1:n          # normalize; cur_rates[n] / sum_rate is exactly 1
        cur_rates[i] /= sum_rate
    end
    sum_rate
end
##### helper functions for sampling jump times ##### | |
# Exponential waiting time with rate `sum_rate`, drawn with `randexp`.
@inline randexp_ziggurat(sum_rate) = randexp() / sum_rate
@inline randexp_ziggurat(rng,sum_rate) = randexp(rng) / sum_rate
# Exponential waiting time with rate `sum_rate`, drawn by inversion. `rand()` lies in
# [0, 1), so `-log1p(-rand())` = `-log(1 - rand())` is always finite and >= 0;
# `-log(rand())` would return `Inf` whenever the uniform draw is exactly 0.
@inline randexp_inverse(sum_rate) = -log1p(-rand()) / sum_rate
@inline randexp_inverse(rng,sum_rate) = -log1p(-rand(rng)) / sum_rate
##### helper functions for sampling jump indices ##### | |
# Inverse-CDF selection of a jump index. `cur_rates` is a non-decreasing normalized
# cumulative rate table and `rng_val` a uniform draw in [0, 1). Returns the first
# index whose cumulative rate is strictly greater than `rng_val`. Reaction i therefore
# owns [cur_rates[i-1], cur_rates[i]), and a zero-rate reaction (equal consecutive
# entries, or a leading 0) is never selected. `searchsortedfirst` (first entry >=
# rng_val) picked index 1 for a draw of exactly 0 even when the first rate was zero.
@inline randidx_bisection(cur_rates,rng_val) = searchsortedlast(cur_rates,rng_val) + 1
##### helper functions for coupled sampling ##### |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment