Last active
July 11, 2023 19:29
-
-
Save slwu89/f6d0c595cea9cc1a86acda470709b140 to your computer and use it in GitHub Desktop.
sandbox for SIR particle filter
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using StochasticAD, Distributions, StaticArrays, Plots | |
using Zygote, ForwardDiff | |
"""
    rate_to_proportion(r, t)

Convert a constant rate `r` acting over an interval of length `t` into the
probability that the event happens within that interval, `1 - exp(-r*t)`.

Evaluated as `-expm1(-r*t)`: for small `r*t` the direct form `1 - exp(-r*t)`
cancels catastrophically, whereas `expm1` keeps full relative precision.
"""
function rate_to_proportion(r, t)
    return -expm1(-r * t)
end
"""
    sir_dyn_mod(x, u0, p, δt)

One-step transition distribution of the stochastic SIR model, expressed on
cumulative event counts.

- `x`: cumulative counts `[N_inf, N_rec]` of infections and recoveries since `u0`.
- `u0`: initial compartment sizes `(S, I, R)`.
- `p`: parameters `(β, c, γ)`; the infection rate is `β*c*I/N` and the recovery rate is `γ`.
- `δt`: step length.

Returns a product distribution whose draw is the cumulative `[N_inf, N_rec]`
after one more step: the old counts plus binomial numbers of new infections
(out of the current `S`) and new recoveries (out of the current `I`).
"""
function sir_dyn_mod(x, u0, p, δt)
    β, c, γ = p
    N_inf, N_rec = x[1], x[2]
    # rebuild the current compartments from the initial state and the event counts
    S0, I0, R0 = u0
    S = S0 - N_inf
    I = I0 + N_inf - N_rec
    R = R0 + N_rec
    N = S + I + R
    # per-individual probabilities of an event during one step of length δt
    p_inf = rate_to_proportion(β * c * I / N, δt)
    p_rec = rate_to_proportion(γ, δt)
    new_inf = Distributions.Binomial(S, p_inf)
    new_rec = Distributions.Binomial(I, p_rec)
    return Distributions.product_distribution([new_inf + N_inf, new_rec + N_rec])
end
"""
    sir_obs_mod(x, u0, p)

Observation model: only infections are observed. Returns a `Poisson`
distribution whose mean is the current number infected,
`I = u0[2] + x[1] - x[2]`, where `x` holds the cumulative infection and
recovery counts. `p` is accepted for signature symmetry but is not used.
"""
function sir_obs_mod(x, u0, p)
    _, I0, _ = u0
    n_inf, n_rec = x[1], x[2]
    infected = I0 + n_inf - n_rec
    return Poisson(infected)
end
"""
    simulate_single(u0, p, nsteps, δt)

Simulate one trajectory of `nsteps` steps from the SIR model.

Returns `(xs, ys)`:
- `xs` is an `(nsteps+1) × 2` matrix of cumulative `[infections, recoveries]`;
  row 1 is all zeros (the initial state `u0`).
- `ys` is a length `nsteps+1` vector of observations; `ys[1]` is the initial
  infected count `u0[2]` (not a random draw), and `ys[k]` for `k > 1` is drawn
  from the observation model after `k-1` steps.
"""
function simulate_single(u0, p, nsteps, δt)
    xs = zeros(Int, nsteps + 1, 2)
    ys = zeros(Int, nsteps + 1)
    ys[1] = u0[2]
    counts = zeros(Int, 2)
    for step in 1:nsteps
        # advance the latent counts first, then observe the new state
        counts = rand(sir_dyn_mod(counts, u0, p, δt))
        xs[step + 1, :] = counts
        ys[step + 1] = rand(sir_obs_mod(counts, u0, p))
    end
    return xs, ys
end
"""
    convert_count_to_state(xs, u0)

Convert a matrix of cumulative event counts (columns: infections, recoveries)
into an integer matrix of compartment sizes with columns `S, I, R`.

Row 1 of the result is always `u0`; the first row of `xs` is not read (it is
assumed to be the zero starting counts).
"""
function convert_count_to_state(xs, u0)
    state = zeros(Int, size(xs, 1), 3)
    state[1, :] = u0
    S0, I0, R0 = u0
    for row in Iterators.drop(axes(xs, 1), 1)
        n_inf, n_rec = xs[row, 1], xs[row, 2]
        state[row, 1] = S0 - n_inf
        state[row, 2] = I0 + n_inf - n_rec
        state[row, 3] = R0 + n_rec
    end
    return state
end
# parameters
tmax = 40.0
δt = 0.1
t = 0:δt:tmax;
# round instead of Int(...): a float ratio such as 2.9999999999999996 would make
# Int(...) throw InexactError
nsteps = round(Int, tmax / δt);
u0 = [990, 10, 0]; # S,I,R
p = [0.05, 10.0, 0.25]; # β,c,γ (δt is passed separately)
# simulate data to use as "observations"
xs, ys = simulate_single(u0, p, nsteps, δt)
# ys[k] is the observation after k-1 steps. This matches the particle filter,
# which assimilates the observation for time index n (n in 1:nsteps) against
# particles that have been propagated n-1 times. The last entry (after nsteps
# steps) has no filter iteration that could use it, so drop it.
ys = ys[1:end-1]
# observe at time index 1 and then every 10th index
ys_sample_rate = [1; 10:10:nsteps]
ys = ys[ys_sample_rate]
state = convert_count_to_state(xs, u0)
plot(state, label=["S" "I" "R"])
scatter!(ys_sample_rate, ys, label="Observed I", alpha=0.5)
# resampling
"""
    sample_stratified(p, K, sump=1)

Draw `K` ancestor indices from the unnormalized weights `p`, whose total is `sump`.

All `K` points `sump * (k - 1 + U) / K` share a single uniform offset `U`, so
despite the name this is systematic resampling. The returned indices are
non-decreasing and lie in `1:length(p)`. Index `i` is selected roughly
`K * p[i] / sump` times.
"""
function sample_stratified(p, K, sump=1)
    n = length(p)
    offset = rand()
    picks = zeros(Int, K)
    i = 1
    cumw = p[1]
    for k in 1:K
        target = sump * (k - 1 + offset) / K
        # walk the cumulative weights up to the target, never past the last index
        while i < n && cumw < target
            i += 1
            @inbounds cumw += p[i]
        end
        picks[k] = i
    end
    return picks
end
"""
    resample(m, X, W, ω, use_new_weight=true)

Resample `m` particles from `X` according to weights `W`, whose sum is `ω`.

The ancestor indices are computed inside `Zygote.ignore`, so Zygote treats them
as constants. Every returned weight has primal value `ω / m`:
- with `use_new_weight`, each weight is routed through StochasticAD's
  `new_weight` (differentiable resampling);
- otherwise it is the plain constant `ω / m` (stop-gradient, biased approach).

Returns `(X_new, W_new)`.
"""
function resample(m, X, W, ω, use_new_weight=true)
    ancestors = Zygote.ignore(() -> sample_stratified(W, m, ω))
    X_new = X[ancestors]
    W_new = if use_new_weight
        # differentiable resampling
        map(w -> ω * new_weight(w / ω) / m, W[ancestors])
    else
        # stop gradient, biased approach
        fill(ω / m, m)
    end
    return X_new, W_new
end
# the simple bootstrap particle filter
"""
    sir_particle_filter(m, dyn, obs, u0, p, δt, nsteps, y, y_times; store_path, use_new_weight)

Run a bootstrap particle filter with `m` particles for `nsteps` steps.

Each particle is the cumulative event-count vector `[infections, recoveries]`,
starting at zero. `dyn(x, u0, p, δt)` gives the transition distribution and
`obs(x, u0, p)` gives the observation distribution. Observation `y[j]` is
assimilated at the start of iteration `n == y_times[j]`, i.e. against particles
that have been propagated `n-1` times. After assimilation the particles are
resampled via `resample(m, X, W, ω, use_new_weight)`.

After each observation the weights sum to the running product of per-observation
mean likelihoods, so `sum(W)` at the end is the likelihood estimate.

Returns `(Xs, W)` when `store_path` is true, where `Xs` holds the particle set
before the first step and after every step (`nsteps + 1` entries). Otherwise it
returns `(X, W)` with the final particles.
"""
function sir_particle_filter(m, dyn, obs, u0, p, δt, nsteps, y, y_times; store_path, use_new_weight)
    X = [zeros(Int, 2) for _ in 1:m] # particles
    W = [1 / m for _ in 1:m]         # weights
    store_path && (Xs = [X])
    for n in 1:nsteps
        # is there an observation for this time index?
        y_ix = findfirst(==(n), y_times)
        if !isnothing(y_ix)
            # update weights & likelihood using observations
            wi = map(x -> pdf(obs(x, u0, p), y[y_ix]), X)
            W = W .* wi
            ω = sum(W) # total weight
            # If every particle has zero likelihood, resampling would compute 0/0
            # and turn every weight into NaN. Keep the all-zero weights instead,
            # so the final likelihood estimate is exactly 0 (log-likelihood -Inf).
            if ω > 0
                X, W = resample(m, X, W, ω, use_new_weight)
            end
        end
        # update particle trajectory
        X = map(x -> rand(dyn(x, u0, p, δt)), X)
        store_path && Zygote.ignore(() -> push!(Xs, X))
    end
    return (store_path ? Xs : X), W
end
# plot and test
n_particles = 100 # used both to run the filter and to unpack its output
Xs, W = sir_particle_filter(
    n_particles, sir_dyn_mod, sir_obs_mod, u0, p, δt, nsteps, ys, ys_sample_rate,
    store_path=true, use_new_weight=true
)
# Xs[t][i] is particle i's cumulative counts at time index t. Stack each index i
# into a (time × 2) matrix, convert it to SIR states, and keep the I column.
# NOTE(review): resampling re-indexes X (X[js]), so index i at different times
# is not one ancestral lineage; these curves are per-index snapshots, not paths.
trajs = [permutedims(reduce(hcat, map(x -> x[i], Xs))) for i in 1:n_particles]
trajs = map(x -> convert_count_to_state(x, u0)[:, 2], trajs)
trajs = reduce(hcat, trajs)
scatter(ys_sample_rate, ys, label="Observed")
plot!(trajs, label=false, alpha=0.1, color=:black)
# try to differentiate the log likelihood
"""
    log_likelihood(p, m, use_new_weight=true)

Particle-filter estimate of the log-likelihood of the observations under the
parameters `p`, using `m` particles. The estimate is `log(sum(W))` of the
filter's final weights.

Reads the globals `u0`, `δt`, `nsteps`, `ys` and `ys_sample_rate`.
"""
function log_likelihood(p, m, use_new_weight=true)
    filter_output = sir_particle_filter(
        m, sir_dyn_mod, sir_obs_mod, u0, p, δt, nsteps, ys, ys_sample_rate;
        store_path=false, use_new_weight=use_new_weight,
    )
    final_weights = last(filter_output)
    return log(sum(final_weights))
end
# Gradients of the log-likelihood estimate w.r.t. p = (β, c, γ), each kept under
# its own name (previously every line overwrote `grad`, losing earlier results).
grad_fd_new_weight = ForwardDiff.gradient(p -> log_likelihood(p, 100, true), p)  # differentiable resampling
grad_fd_stopgrad = ForwardDiff.gradient(p -> log_likelihood(p, 100, false), p)   # stop-gradient resampling
grad_zygote = Zygote.gradient(p -> log_likelihood(p, 100, true), p)              # reverse mode; returns a tuple
grad = grad_zygote # keep the script's original final binding
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment