Skip to content

Instantly share code, notes, and snippets.

@slwu89
Last active July 11, 2023 19:29
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save slwu89/f6d0c595cea9cc1a86acda470709b140 to your computer and use it in GitHub Desktop.
Save slwu89/f6d0c595cea9cc1a86acda470709b140 to your computer and use it in GitHub Desktop.
sandbox for SIR particle filter
using StochasticAD, Distributions, StaticArrays, Plots
using Zygote, ForwardDiff
"""
    rate_to_proportion(r, t)

Convert a constant rate `r` acting over an interval of length `t` into the
proportion `1 - exp(-r*t)` (the probability of at least one event when the
waiting time is exponential with rate `r`).

Evaluated as `-expm1(-r*t)`. For tiny `r*t`, the naive `1 - exp(-r*t)` suffers
catastrophic cancellation and can round to exactly 0. `expm1` keeps full
relative precision there and agrees with the naive formula elsewhere.
"""
function rate_to_proportion(r, t)
    return -expm1(-r * t)
end
"""
    sir_dyn_mod(x, u0, p, δt)

Transition kernel of the discrete-time stochastic SIR model.

# Arguments
- `x`: cumulative event counts so far, `x[1]` = infections and `x[2]` = recoveries.
- `u0`: initial compartments `(S, I, R)`. The current compartments are rebuilt
  from `u0` and `x`.
- `p`: parameters `(β, c, γ)`.
- `δt`: step length.

Over one step, each of the `S` susceptibles is infected with probability
`rate_to_proportion(β*c*I/N, δt)`. Each of the `I` infectives recovers with
probability `rate_to_proportion(γ, δt)`.

Returns a product distribution over the *updated* cumulative counts, i.e. the
binomial numbers of new events shifted by the current `x`.
"""
function sir_dyn_mod(x, u0, p, δt)
    β, c, γ = p
    S0, I0, R0 = u0
    ninf, nrec = x[1], x[2]

    # current compartment sizes implied by the cumulative event counts
    S = S0 - ninf
    I = I0 + ninf - nrec
    R = R0 + nrec
    N = S + I + R

    # per-susceptible infection rate (frequency-dependent mixing)
    λ_inf = β * c * I / N
    new_inf = Distributions.Binomial(S, rate_to_proportion(λ_inf, δt))
    new_rec = Distributions.Binomial(I, rate_to_proportion(γ, δt))

    return Distributions.product_distribution([new_inf + ninf, new_rec + nrec])
end
"""
    sir_obs_mod(x, u0, p)

Observation model of the SIR system.

The observation is Poisson with mean equal to the *current* number of
infectives, `I = u0[2] + x[1] - x[2]`. This is prevalence, rebuilt from the
cumulative infection/recovery counts in `x`, not the number of new infections
in the step.

`p` is accepted so the signature parallels the dynamics model, but it is not used.
"""
function sir_obs_mod(x, u0, p)
    _, I0, _ = u0
    cum_inf, cum_rec = x[1], x[2]
    infectives = I0 + cum_inf - cum_rec
    return Poisson(infectives)
end
"""
    simulate_single(u0, p, nsteps, δt)

Simulate one trajectory of `nsteps` steps from `sir_dyn_mod`, observed via `sir_obs_mod`.

# Returns
- `xs`: an `(nsteps + 1) × 2` `Int` matrix of cumulative `[infections recoveries]`
  counts. Row 1 is all zeros; row `n + 1` is the state after step `n`.
- `ys`: a length-`(nsteps + 1)` `Int` vector. `ys[1] = u0[2]` is the initial
  infective count itself, not a draw from the observation model. `ys[n + 1]` is
  drawn from `sir_obs_mod` after step `n`.
"""
function simulate_single(u0, p, nsteps, δt)
    xs = zeros(Int, nsteps + 1, 2)
    ys = zeros(Int, nsteps + 1)
    ys[1] = u0[2]

    counts = zeros(Int, 2)
    for step in 1:nsteps
        # advance the latent counts first, then observe the new state
        counts = rand(sir_dyn_mod(counts, u0, p, δt))
        xs[step + 1, :] = counts
        ys[step + 1] = rand(sir_obs_mod(counts, u0, p))
    end
    return xs, ys
end
"""
    convert_count_to_state(xs, u0)

Convert cumulative event counts into SIR compartment sizes (used for plotting).

# Arguments
- `xs`: matrix with one row per time point. Column 1 holds cumulative
  infections and column 2 cumulative recoveries.
- `u0`: initial state `(S, I, R)`.

# Returns
A `size(xs, 1) × 3` matrix of `[S I R]`.
- Row 1 is `u0` itself; `xs[1, :]` is not consulted.
- Each later row `n` is `(S0 - xs[n,1], I0 + xs[n,1] - xs[n,2], R0 + xs[n,2])`.
- The element type is `promote_type(eltype(xs), eltype(u0))`: integer inputs
  give an integer matrix, and fractional counts are supported.
- An `xs` with no rows gives an empty `0 × 3` matrix.
"""
function convert_count_to_state(xs, u0)
    S0, I0, R0 = u0
    T = promote_type(eltype(xs), eltype(u0))
    nrows = size(xs, 1)
    state = zeros(T, nrows, 3)
    nrows == 0 && return state

    # row 1 is pinned to the initial state
    state[1, 1], state[1, 2], state[1, 3] = S0, I0, R0
    for n in 2:nrows
        ninf, nrec = xs[n, 1], xs[n, 2]
        state[n, 1] = S0 - ninf
        state[n, 2] = I0 + ninf - nrec
        state[n, 3] = R0 + nrec
    end
    return state
end
# parameters
tmax = 40.0
δt = 0.1
t = 0:δt:tmax;
# tmax / δt is a floating-point quotient that can land just off an integer
# (e.g. 0.3 / 0.1 == 2.9999999999999996), where Int(...) would throw InexactError.
nsteps = round(Int, tmax / δt);
u0 = [990,10,0]; # S,I,R
p = [0.05,10.0,0.25]; # β,c,γ (the step size δt is passed separately)
# simulate data to use as "observations"
xs, ys = simulate_single(u0, p, nsteps, δt)
# simulate_single returns ys[k] = observation taken k-1 steps in (ys[1] is the initial I).
# Drop the final entry so ys is indexed by step 1:nsteps.
ys = ys[1:end-1]
# steps at which observations are assumed available: step 1, then every 10 steps
ys_sample_rate = [1; 10:10:nsteps]
ys = ys[ys_sample_rate]
state = convert_count_to_state(xs, u0)
plot(state, label=["S" "I" "R"])
scatter!(ys_sample_rate, ys, label="Observed I", alpha=0.5)
# resampling
"""
    sample_stratified(p, K, sump=1; U=rand())

Draw `K` particle indices from the (possibly unnormalised) weights `p`, whose
total is `sump`.

The k-th threshold is `sump * (k - 1 + U) / K`, and index `i` is chosen for it
by walking the cumulative weights. One offset `U ∈ [0, 1)` is shared by all `K`
strata, so this is the scheme usually called *systematic* resampling. Pass `U`
explicitly for reproducible draws; by default it is drawn fresh on each call.

# Returns
A `Vector{Int}` of length `K` with nondecreasing indices in `1:length(p)`.

# Throws
`ArgumentError` if `p` is empty.
"""
function sample_stratified(p, K, sump=1; U=rand())
    n = length(p)
    n > 0 || throw(ArgumentError("weights `p` must be non-empty"))
    is = zeros(Int, K)
    i = 1
    cw = p[1]  # cumulative weight of p[1:i]
    for k in 1:K
        threshold = sump * (k - 1 + U) / K
        # advance until the cumulative weight reaches the threshold; never step past n
        while cw < threshold && i < n
            i += 1
            @inbounds cw += p[i]
        end
        is[k] = i
    end
    return is
end
"""
    resample(m, X, W, ω, use_new_weight=true)

Resample `m` particles from `X` according to weights `W`, whose total is `ω`.
The particle filter passes `ω = sum(W)`.

The indices come from `sample_stratified` and are computed inside
`Zygote.ignore`, so Zygote does not differentiate through the discrete selection.

# Weights after resampling
- `use_new_weight=true`: each surviving particle gets `ω * new_weight(w / ω) / m`,
  where `w` is its pre-resampling weight. `new_weight` is not defined in this
  file; presumably it is StochasticAD's `new_weight`, which carries the
  derivative of the selection probability. Verify.
- `use_new_weight=false`: every particle gets `ω / m`. This is the
  stop-gradient, biased variant.

Returns `(X_new, W_new)`.
"""
function resample(m, X, W, ω, use_new_weight=true)
    chosen = Zygote.ignore(() -> sample_stratified(W, m, ω))
    survivors = X[chosen]
    reweighted = if use_new_weight
        # differentiable resampling
        map(W[chosen]) do w
            ω * new_weight(w / ω) / m
        end
    else
        # stop gradient, biased approach
        fill(ω / m, m)
    end
    return survivors, reweighted
end
# the simple bootstrap particle filter
"""
    sir_particle_filter(m, dyn, obs, u0, p, δt, nsteps, y, y_times; store_path, use_new_weight)

Run a bootstrap particle filter with `m` particles for `nsteps` steps.

Each particle is a length-2 `Int` vector of cumulative event counts, starting at
zero. `dyn(x, u0, p, δt)` and `obs(x, u0, p)` interpret it relative to the
initial state `u0`.

At each step `n`:
1. If `n` occurs in `y_times` (first match at position `k`), every weight is
   multiplied by `pdf(obs(x, u0, p), y[k])`. The running total `ω` is set to
   `sum(W)`, and the particles are resampled with `resample(m, X, W, ω, use_new_weight)`.
2. Every particle is then advanced by one draw from `dyn`.

The observation for step `n` is therefore compared with the particles *before*
the `n`-th propagation.

# Returns
`(Xs, W)` when `store_path=true`, where `Xs` holds the initial particle set plus
the set after each propagation (`nsteps + 1` entries). Otherwise `(X, W)` with
only the final particle set.

With `use_new_weight=false`, the final `W` sums to the running `ω`: the product,
over observation times, of the mean particle likelihood. `log_likelihood` uses
`log(sum(W))`.
"""
function sir_particle_filter(m, dyn, obs, u0, p, δt, nsteps, y, y_times; store_path, use_new_weight)
X = [zeros(Int,2) for _ in 1:m] # particles: cumulative [infections, recoveries], all start at zero
W = [1 / m for _ in 1:m] # weights: uniform, summing to 1
ω = 1 # total weight; becomes a float (or dual/AD type) after the first observation
# Xs is only bound when store_path is true, and it is only read under the same condition
store_path && (Xs = [X])
for n in 1:nsteps
# do we need to recalculate likelihood?
# (is step n one of the observation times? y_ix indexes into y)
y_ix = findfirst(x->x==n, y_times)
if !isnothing(y_ix)
# update weights & likelihood using observations
wi = map(x -> pdf(obs(x, u0, p), y[y_ix]), X)
W = W .* wi
ω = sum(W)
# resample particles
X, W = resample(m, X, W, ω, use_new_weight)
end
# update particle trajectory
# X is rebound to a fresh vector (never mutated), so earlier entries stored in Xs stay intact
X = map(x -> rand(dyn(x, u0, p, δt)), X)
# push! is mutation; Zygote.ignore keeps it out of Zygote's trace
store_path && Zygote.ignore(() -> push!(Xs, X))
end
(store_path ? Xs : X), W
end
# plot and test
nparticles = 100
Xs, W = sir_particle_filter(
    nparticles, sir_dyn_mod, sir_obs_mod, u0, p, δt, nsteps, ys, ys_sample_rate,
    store_path=true, use_new_weight=true
)
# trajs[i]: (nsteps + 1) × 2 matrix of particle index i's cumulative counts at each stored time.
# NOTE: `resample` re-indexes the particle set, so index i at consecutive times need not be
# the same lineage. These curves are per-index snapshots, not ancestral paths.
trajs = [permutedims(reduce(hcat, [X[i] for X in Xs])) for i in 1:nparticles]
# keep only the infective compartment (column 2) of each reconstructed state matrix
trajs = reduce(hcat, [convert_count_to_state(traj, u0)[:, 2] for traj in trajs])
scatter(ys_sample_rate, ys, label="Observed")
plot!(trajs, label=false, alpha=0.1, color=:black)
# try to differentiate the log likelihood
"""
    log_likelihood(p, m, use_new_weight=true)

Return `log(sum(W))`, where `W` are the final weights of `sir_particle_filter`.
The filter runs with `m` particles under parameters `p = (β, c, γ)` against the
observations `ys` at steps `ys_sample_rate`. This is the filter's log-likelihood
estimate; it is random, since the filter samples.

`use_new_weight` is forwarded to the filter's resampling step.

Reads the globals `u0`, `δt`, `nsteps`, `ys` and `ys_sample_rate`.
"""
function log_likelihood(p, m, use_new_weight=true)
    _, final_weights = sir_particle_filter(
        m, sir_dyn_mod, sir_obs_mod, u0, p, δt, nsteps, ys, ys_sample_rate;
        store_path=false, use_new_weight=use_new_weight,
    )
    total_weight = sum(final_weights)
    return log(total_weight)
end
# Gradients of the log-likelihood estimate w.r.t. p = (β, c, γ), each with 100 particles.
# Each result keeps its own name so the estimators can be compared side by side.
grad_fd_new_weight = ForwardDiff.gradient(p -> log_likelihood(p, 100, true), p)
grad_fd_stop_gradient = ForwardDiff.gradient(p -> log_likelihood(p, 100, false), p)
# Zygote.gradient returns a tuple with one entry per argument
grad_zygote_new_weight = Zygote.gradient(p -> log_likelihood(p, 100, true), p)
# `grad` still refers to the last gradient computed
grad = grad_zygote_new_weight
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment