Instantly share code, notes, and snippets.
Last active: March 15, 2020 08:59
Save alexyakunin/5ef2f4cbedf1da5c2c90633bc7ce598e to your computer and use it in GitHub Desktop.
An attempt to estimate COVID-19 transmission rate
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def clamp(x, x_min, x_max):
    """Limit *x* to the closed range ``[x_min, x_max]``.

    If ``x_min > x_max`` the lower bound wins and ``x_min`` is returned.
    """
    capped = min(x_max, x)
    return max(x_min, capped)
def safe_segment(a, start=0, end=None):
    """Return ``a[start:end]`` with both bounds clamped into ``[0, len(a)]``.

    Unlike plain slicing, a negative bound is clamped to 0 instead of being
    counted from the end, and ``end`` is never allowed to fall before
    ``start`` (such a window is simply empty). ``end=None`` means "up to the
    end of ``a``".
    """
    size = len(a)
    lo = clamp(start, 0, size)
    hi = clamp(size if end is None else end, lo, size)
    return a[lo:hi]
def safe_sum(a, start=0, end=None):
    """Sum the window of ``a`` selected by ``safe_segment(a, start, end)``.

    Out-of-range or inverted bounds yield an empty window, whose sum is 0.
    """
    window = safe_segment(a, start, end)
    return sum(window)
class State:
    """One day of a simple cohort-based epidemic simulation.

    ``infections[i]`` holds the new infections that appeared on day ``i`` and
    ``deaths[i]`` the deaths recorded on day ``i``; ``next_state`` appends one
    entry to each list per simulated day. All periods are measured in days.

    The model parameters are *class* attributes shared by every instance:
    ``find_transmission_rate``, ``find_transmission_and_death_rate`` and
    ``model_case`` overwrite them directly on the class.
    """

    report_period = 7  # infections this many days old count as "reported"
    transmission_rate = 2.5  # placeholder only; overwritten by the rate fitting
    transmission_period = 7  # infections from the last N days are contagious
    death_rate = 0.05  # fraction of a day's infections that die death_period days later
    death_period = 12  # days from infection to death
    cure_period = 28  # days from infection until a survivor counts as cured

    def __init__(self):
        self.infections = []  # new infections per day (index = day)
        self.deaths = []  # deaths recorded per day (index = day of death)

    def __str__(self):
        ri = int(self.reported_infected_count())
        i = int(self.infected_count())
        c = int(self.cured_count())
        d = int(self.dead_count())
        return f'Day #{self.day_index()}: reported={ri} cured={c} died={d} (infected={i})'

    def day_index(self):
        """Return the number of simulated days so far."""
        return len(self.infections)

    def daily_transmission_rate(self):
        """Return the per-day growth rate that compounds to ``1 + transmission_rate``
        over ``transmission_period`` days."""
        return pow(self.transmission_rate + 1, 1.0 / self.transmission_period) - 1

    def contagious_count(self):
        """Return the infections from the last ``transmission_period`` days."""
        return safe_sum(self.infections, len(self.infections) - self.transmission_period)

    def reported_infected_count(self):
        """Return the infections that are at least ``report_period`` days old."""
        return safe_sum(self.infections, 0, len(self.infections) - self.report_period)

    def infected_count(self):
        """Return all infections so far, reported or not."""
        return sum(self.infections)

    def sick_count(self):
        """Return people infected within the last ``cure_period`` days who have not died.

        A cohort infected on day ``i`` dies on day ``i + death_period``, so the
        deaths belonging to the last ``cure_period`` days of infections are
        the deaths recorded in the last ``cure_period - death_period`` days.
        """
        day_count = len(self.infections)
        recent_infections = safe_sum(self.infections, day_count - self.cure_period)
        # Subtract recorded deaths; summing ``infections`` here would ignore deaths entirely.
        recent_deaths = safe_sum(self.deaths, day_count - (self.cure_period - self.death_period))
        return recent_infections - recent_deaths

    def dead_count(self):
        """Return all deaths so far."""
        return sum(self.deaths)

    def cured_count(self):
        """Return people whose illness has resolved without death.

        Infections on days before ``cohort_end`` are past ``cure_period``.
        Because ``deaths`` is indexed by the day of death, which is
        ``death_period`` days after infection, the deaths of that cohort are
        ``deaths[:cohort_end + death_period]``, not ``deaths[:cohort_end]``.
        """
        # NOTE(review): the extra ``- 1`` makes this window one day shorter
        # than reported_infected_count()'s; kept to preserve existing output.
        cohort_end = len(self.infections) - self.cure_period - 1
        if cohort_end <= 0:
            return 0
        resolved = safe_sum(self.infections, 0, cohort_end)
        died = safe_sum(self.deaths, 0, cohort_end + self.death_period)
        return resolved - died

    def next_state(self, index=1):
        """Return the state ``index`` days after this one.

        ``index == 0`` returns ``self``; a negative ``index`` advances exactly
        one day. This state is never mutated: every simulated day produces a
        new ``State`` whose lists are one entry longer.
        """
        if index == 0:
            return self
        state = self
        # Iterate instead of recursing so long horizons cannot hit the recursion limit.
        for _ in range(max(index, 1)):
            state = state._advance_one_day()
        return state

    def _advance_one_day(self):
        """Return a new ``State`` one simulated day after this one."""
        death_pool_day = self.day_index() - self.death_period
        # Today's deaths come only from the cohort infected exactly death_period days ago.
        death_pool_count = safe_sum(self.infections, death_pool_day, death_pool_day + 1)
        new_infections = self.contagious_count() * self.daily_transmission_rate()
        new_deaths = death_pool_count * self.death_rate
        r = State()
        r.infections = self.infections + [new_infections]
        r.deaths = self.deaths + [new_deaths]
        return r
def find_transmission_rate(start_reported=68, end_reported=3000, days_between=14,
                           min_rate=1, max_rate=1000, tolerance=0.00001, max_days=200):
    """Fit ``State.transmission_rate`` to two reported-case observations by bisection.

    For each candidate rate the model starts from a single infection. It is
    advanced until ``start_reported`` cases are reported (or ``max_days``
    days pass), and then ``days_between`` more days. The rate is raised
    while the later state reports fewer than ``end_reported`` cases and
    lowered otherwise, until ``[min_rate, max_rate]`` is narrower than
    ``tolerance``. The defaults are the targets the script was written for.
    Bisection assumes reported counts grow monotonically with the rate.

    Side effect: ``State.transmission_rate`` is left at the last probed rate,
    which is the rate the returned states were simulated with.

    Returns:
        ``(s, e)``: the state at which ``start_reported`` was reached and the
        state ``days_between`` days later, for the last probed rate.

    Raises:
        ValueError: if ``max_rate - min_rate`` is not larger than
            ``tolerance``, in which case no rate would ever be probed.
    """
    if not (max_rate - min_rate) > tolerance:
        raise ValueError('max_rate - min_rate must be greater than tolerance')
    while (max_rate - min_rate) > tolerance:
        next_rate = (min_rate + max_rate) / 2
        State.transmission_rate = next_rate
        s = State()
        s.infections = [1]
        s.deaths = [0]
        # The day cap keeps very low rates from simulating forever.
        while s.reported_infected_count() < start_reported and s.day_index() < max_days:
            s = s.next_state()
        e = s.next_state(days_between)
        if e.reported_infected_count() < end_reported:
            min_rate = next_rate
        else:
            max_rate = next_rate
    return s, e
def find_transmission_and_death_rate():
    """Fit ``State.death_rate`` by bisection, refitting the transmission rate per probe.

    Each candidate death rate in ``[0.001, 0.3]`` is written to the class,
    and ``find_transmission_rate()`` is re-run. The candidate is raised
    while the resulting later state has fewer than 60 deaths and lowered
    otherwise, until the interval is at most 0.001 wide. Both class
    attributes are left at the last probed values, and that probe's
    ``(s, e)`` pair is returned.
    """
    low, high = 0.001, 0.3
    while high - low > 0.001:
        candidate = (low + high) / 2
        State.death_rate = candidate
        s, e = find_transmission_rate()
        too_few_deaths = e.dead_count() < 60
        if too_few_deaths:
            low = candidate
        else:
            high = candidate
    return s, e
def model_case(title, report_period, transmission_period, cure_period, death_period):
    """Run and print one modelling scenario.

    The scenario's periods are printed and written onto the ``State`` class
    (affecting every ``State`` afterwards). Transmission and death rates are
    then fitted, and each simulated day is printed from the fitted starting
    state until reported cases exceed 3100.
    """
    settings = (
        ('report_period', report_period),
        ('transmission_period', transmission_period),
        ('cure_period', cure_period),
        ('death_period', death_period),
    )
    print(f'"{title}" case:')
    for name, value in settings:
        print(f' {name}: {value}')
    for name, value in settings:
        setattr(State, name, value)

    state, _ = find_transmission_and_death_rate()
    print('Computed rates:')
    print(f' transmission_rate: {State.transmission_rate}')
    print(f' death_rate: {State.death_rate}')
    print()

    print('Modelling Feb 29 ... Mar 14:')
    while state.reported_infected_count() <= 3100:
        print(state)
        state = state.next_state()
    print()
# Scenarios, from longest to shortest reporting/transmission periods.
model_case('Pessimistic (=15)', report_period=10, transmission_period=10,
           cure_period=24, death_period=12)
model_case('Realistic (=7.5)', report_period=7, transmission_period=7,
           cure_period=21, death_period=12)
model_case('Ok-ish (=2.7)', report_period=3, transmission_period=3,
           cure_period=17, death_period=12)
model_case('Official (=2)', report_period=2, transmission_period=2,
           cure_period=17, death_period=12)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
Output: