Skip to content

Instantly share code, notes, and snippets.

@jackd
Created August 25, 2021 09:25
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jackd/dd81161726b661c3c2a651e039305f04 to your computer and use it in GitHub Desktop.
Save jackd/dd81161726b661c3c2a651e039305f04 to your computer and use it in GitHub Desktop.
A super-simple Susceptible-Infectious-Recovered (SIR) model with vaccination.
"""Basic Susceptible-Infectious-Recovered model."""
import numpy as np
import matplotlib.pyplot as plt
def run(
    infectious: float,
    r0=1.07,  # base daily growth rate
    vax_rate=0,
    init_vax_prop=0,
    population=24000000,
    steps=180,
):
    """Simulate new infections per step in a simple discrete-time epidemic model.

    On each step, the new infections equal the previous step's new infections
    multiplied by ``r0`` and by the fraction of the population that is still
    susceptible. Those new infections, plus ``vax_rate`` vaccinations, are then
    removed from the susceptible pool. Only the most recent step's infections
    drive the next step; there is no separately tracked recovered compartment.

    Args:
        infectious: New infections on step 0. May be fractional.
        r0: Per-step growth multiplier that applies when the whole population
            is susceptible.
        vax_rate: Number of susceptible people removed (vaccinated) per step.
        init_vax_prop: Fraction of the population already vaccinated, and so
            not susceptible, at step 0.
        population: Total population size.
        steps: Number of steps to simulate, including step 0. Must be >= 1.

    Returns:
        A float array of shape ``(steps,)`` holding the new infections on each
        step. Element 0 equals ``infectious``.

    Raises:
        ValueError: If ``steps`` is less than 1.
    """
    if steps < 1:
        raise ValueError(f"steps must be at least 1, got {steps}")
    # Clamp so that an initial case count larger than the unvaccinated
    # population cannot start the model with a negative susceptible pool.
    susceptible = max(population * (1 - init_vax_prop) - infectious, 0)
    inf = np.zeros((steps,))
    inf[0] = infectious
    for s in range(1, steps):
        infectious *= susceptible / population * r0
        # New infections are drawn from the susceptible pool, so they can never
        # exceed it. Without this cap a large r0 lets total infections exceed
        # the population.
        infectious = min(max(infectious, 0), susceptible)
        susceptible = max(susceptible - (infectious + vax_rate), 0)
        inf[s] = infectious
    return inf
# Scenario settings shared with the comparison plots further down.
title = "SIR model with vaccination"
steps = 100
twinx = True

# Vaccination scenario. NOTE(review): 0.8 is applied both to the initial
# coverage and to daily vaccinations, so it is presumably vaccine
# effectiveness -- confirm.
INITIAL_VAX_COVERAGE = 0.7
VAX_EFFECTIVENESS = 0.8
DAILY_VACCINATIONS = 25000
BASE_GROWTH = 1.07

INIT_VAX_PROP = INITIAL_VAX_COVERAGE * VAX_EFFECTIVENESS
kwargs = dict(
    init_vax_prop=INIT_VAX_PROP,
    vax_rate=DAILY_VACCINATIONS * VAX_EFFECTIVENESS,
    # Divide by the initial susceptible fraction so the first-step growth
    # multiplier (r0 * susceptible / population) is about BASE_GROWTH, the
    # same as the unvaccinated default in `run`.
    r0=BASE_GROWTH / (1 - INIT_VAX_PROP),
)

# Total infections as a function of the initial case count, at 11
# log-spaced points from 1 to 100,000. The horizon is much longer than
# `steps`, presumably so each simulated epidemic runs to completion.
SWEEP_STEPS = 1000
initial_cases = np.geomspace(1, 100000, 11)
total_infections = [
    np.sum(run(n0, steps=SWEEP_STEPS, **kwargs)) for n0 in initial_cases
]
plt.loglog(initial_cases, total_infections)
plt.xlabel("Initial infections")
plt.ylabel("Total infections")
plt.title(f"SIR model with {DAILY_VACCINATIONS:,} vaccinations / day")
plt.show()
# To produce the basic (unvaccinated) comparison instead, uncomment these
# overrides of the scenario settings defined above:
# title = "Basic SIR"
# steps = 180
# kwargs = {}
# twinx = False

# Compare epidemics seeded with 30 and with 800 initial cases under
# identical settings.
days = np.arange(steps)
y30 = run(30, steps=steps, **kwargs)
y800 = run(800, steps=steps, **kwargs)
t30 = y30.sum()
t800 = y800.sum()
print(f"total from 30 cases = {int(t30)}")
print(f"total from 800 cases = {int(t800)}")
print(f"More by a factor of {t800 / t30:.3f}")

if twinx:
    # Give each curve its own y-axis (left/red: 800 cases, right/blue:
    # 30 cases) so their shapes can be compared despite different scales.
    plt.figure()
    plt.title(title)
    ax800 = plt.gca()
    ax30 = ax800.twinx()
    (h800,) = ax800.plot(days, y800, color="red", label="800 initial cases")
    (h30,) = ax30.plot(days, y30, color="blue", label="30 initial cases")
    ax800.tick_params(axis="y", labelcolor="red")
    ax30.tick_params(axis="y", labelcolor="blue")
    ax800.set_xlabel("Day")
    ax800.set_ylabel("Daily infections")
    # Attach the legend explicitly to the twin axes, which is drawn above
    # ax800, instead of relying on pyplot's implicit "current axes".
    ax30.legend(handles=(h800, h30))

# Same two curves on a shared y-axis, showing their absolute sizes.
plt.figure()
plt.plot(days, y800, label="800 initial cases", color="red")
plt.plot(days, y30, label="30 initial cases", color="blue")
plt.title(title)
plt.xlabel("Day")
plt.ylabel("Daily infections")
plt.legend()
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment