Skip to content

Instantly share code, notes, and snippets.

@YashasSamaga
Last active November 9, 2020 13:45
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save YashasSamaga/fe04b91066655e4ab2d406af1f176838 to your computer and use it in GitHub Desktop.
Save YashasSamaga/fe04b91066655e4ab2d406af1f176838 to your computer and use it in GitHub Desktop.
covid_estimate_sum_of_gaussians.py
import requests
import numpy as np
import matplotlib.pyplot as plt
# Download the daily per-region counts published by covid19india.org.
# The timeout and explicit status check turn a stalled or failed request into
# a clear HTTP error instead of a hang or an obscure JSON decoding failure.
r = requests.get('https://api.covid19india.org/states_daily.json', timeout=30)
r.raise_for_status()
data = r.json()
print(len(data['states_daily']))

from datetime import datetime

date_format = "%Y-%m-%d"
start_date = datetime.strptime('2020-3-14', date_format)
end_date = datetime.strptime('2020-11-2', date_format)
# Number of days in the half-open window [start_date, end_date).
duration = (end_date - start_date).days

# Two-letter keys read from every row of the feed.
states = ['an', 'ap', 'ar', 'as', 'br', 'ch', 'ct', 'dd', 'dl', 'dn', 'ga', 'gj', 'hp', 'hr', 'jh', 'jk', 'ka', 'kl', 'la', 'ld', 'mh', 'ml', 'mn', 'mp', 'mz', 'nl', 'or', 'pb', 'py', 'rj', 'sk', 'tg', 'tn', 'tr', 'tt', 'un', 'up', 'ut', 'wb']

# stats[key][i] accumulates the 'Confirmed' count reported i days after
# start_date.
stats = {}
for s in states:
    stats[s] = [0] * duration

for row in data['states_daily']:
    if row['status'] != 'Confirmed':
        continue
    date = datetime.strptime(row['dateymd'], date_format)
    idx = (date - start_date).days
    # Check both ends of the window: a negative idx (a row dated before
    # start_date) would otherwise wrap around and be added to the tail.
    if not 0 <= idx < duration:
        continue
    for s in states:
        stats[s][idx] += int(row[s])
def fit_gaussian(data):
    """Fit ``A * exp(-((x - mu) / sigma)**2 / 2)`` to a 1-D series.

    The series is fitted against its own sample indices ``0 .. len(data)-1``.
    The function uses only its argument; it does not read any module-level
    state.

    Args:
        data: sequence of numbers, one value per sample (at least three).

    Returns:
        NumPy array ``[A, mu, sigma]`` with the fitted parameters.  The model
        depends on ``sigma`` only through its square, so its sign is not
        determined by the fit.

    Raises:
        ValueError: if fewer than three samples are supplied.
        RuntimeError: raised by ``curve_fit`` when the optimisation fails.
    """
    from scipy.optimize import curve_fit

    def fit_function(x, A, mu, sigma):
        z = (x - mu) / sigma
        return A * np.exp(-z * z / 2)

    y = np.asarray(data, dtype=float)
    if y.size < 3:
        raise ValueError("fit_gaussian needs at least three samples")

    # Sample i sits at x == i, so the fitted mu is expressed in sample indices.
    x = np.arange(y.size, dtype=float)

    # Starting guess: the highest sample gives the height and position; the
    # width comes from area = A * sigma * sqrt(2*pi) for a Gaussian bump.
    peak = float(np.max(y))
    centre = float(np.argmax(y))
    width = 1.0  # curve_fit's own default starting value
    if peak > 0:
        estimate = float(np.sum(y)) / (peak * np.sqrt(2.0 * np.pi))
        if np.isfinite(estimate) and estimate > 0:
            width = estimate

    popt, _ = curve_fit(fit_function, x, y, p0=[peak, centre, width])
    return popt
class COVIDCurve:
    """Daily-count curve modelled as a sum of scaled normal densities.

    Every component contributes ``A * norm.pdf(x, mu, sigma)``.  After
    ``fit`` the parameters are stored flat in ``self.params`` as
    ``[A0, mu0, sigma0, A1, mu1, sigma1, ...]``.
    """

    def __init__(self, n_components):
        """Create an unfitted curve.

        Args:
            n_components: number of components ``fit`` is required to end up
                with, or ``None`` to accept whatever ``fit`` detects.
        """
        self.n_components = n_components
        # Flat parameter vector; stays None until fit() succeeds.
        self.params = None

    @staticmethod
    def fit_function(x, *args):
        """Evaluate the mixture at ``x`` (scalar or NumPy array).

        ``args`` is the flat parameter list; trailing values that do not
        complete an ``(A, mu, sigma)`` triple are ignored.
        """
        from scipy.stats import norm

        n_components = len(args) // 3
        result = 0.0
        for c in range(n_components):
            A, mu, sigma = args[c * 3 : c * 3 + 3]
            result = result + A * norm.pdf(x, mu, sigma)
        return result

    @staticmethod
    def _moving_average(x, window_size = 3):
        """Return a centred moving average of ``x`` with the same length.

        The right edge is padded with a mirror image of the last
        ``window_size`` samples so the average does not fall off at the end
        of the series; the left edge is zero-padded by ``np.convolve``.
        """
        x = np.asarray(x, dtype=float)
        kernel = np.full(window_size, 1.0 / window_size)
        # x[-window_size:][::-1] is the tail in reverse order.  A slice such
        # as x[-1:-window_size] is always empty and would add no padding.
        padded = np.append(x, x[-window_size:][::-1])
        # 'same' returns one value per padded sample; keep the real ones.
        return np.convolve(padded, kernel, 'same')[:len(x)]

    def fit(self, data, add_extra_gaussian):
        """Fit the mixture to ``data`` (one value per day, day 0 first).

        Peaks of the 31-sample moving average (prominence >= 300,
        width >= 20) seed one component each; with no such peak a single
        component is seeded from the maximum of ``data``.

        Args:
            data: sequence of daily values.
            add_extra_gaussian: when true, seed one more component centred
                at ``len(data)``, i.e. just past the last sample.

        Raises:
            ValueError: if ``data`` is empty, or if ``n_components`` was
                given and differs from the number of seeded components.
            RuntimeError: raised by ``curve_fit`` when optimisation fails.
        """
        from scipy.optimize import curve_fit
        from scipy.signal import find_peaks

        if len(data) == 0:
            raise ValueError("cannot fit an empty series")

        smoothed_points = self._moving_average(data, 31)
        peaks, peak_props = find_peaks(smoothed_points, prominence = 300, width = 20)

        # NOTE(review): the amplitude seeds are curve heights, while A
        # multiplies norm.pdf; kept unchanged as a heuristic, confirm intent.
        initial_weights = smoothed_points[peaks]
        initial_means = peaks
        initial_sigma = peak_props['widths']
        if len(initial_weights) == 0:
            initial_weights = [np.max(data)]
            initial_means = [np.argmax(data)]
            initial_sigma = [np.std(data)]
        if add_extra_gaussian:
            initial_weights = np.append(initial_weights, [np.max(data)])
            initial_means = np.append(initial_means, [len(data)])
            initial_sigma = np.append(initial_sigma, [np.std(data)])

        if self.n_components is not None and len(initial_weights) != self.n_components:
            raise ValueError(
                "expected %d components but seeded %d"
                % (self.n_components, len(initial_weights))
            )
        # At least one component is always seeded by the fallback above.
        self.n_components = len(initial_weights)
        print(peaks, peak_props)
        print(initial_weights, initial_means, initial_sigma)

        initials = []
        for weight, mean, sigma in zip(initial_weights, initial_means, initial_sigma):
            initials += [weight, mean, sigma]

        # Sample i sits at x == i, matching the peak indices used as seeds
        # and the integer days passed to compute().
        x = np.arange(len(data), dtype=float)
        popt, _ = curve_fit(self.fit_function, x, data, p0 = initials)
        self.params = popt

    def compute(self, x):
        """Evaluate the fitted curve at ``x`` (scalar or NumPy array).

        Raises:
            RuntimeError: if ``fit`` has not produced parameters yet.
        """
        if self.params is None:
            raise RuntimeError("fit() must be called before compute()")
        return self.fit_function(x, *self.params)

    def plot(self, title, days):
        """Plot the fitted curve at ``days`` points spread over [0, days]."""
        x = np.linspace(0, days, days)
        # fit_function is vectorised, so the whole grid goes in one call.
        plt.plot(x, self.compute(x))
        plt.title(title)
        plt.show()
def plot_aggregate(models, days = 365, title = ""):
    """Plot the sum of all fitted curves for days ``0 .. days-1``.

    Args:
        models: mapping whose values expose ``compute(x)``; ``compute`` is
            called once per model with a NumPy array of integer days.
        days: number of days to evaluate, starting at day 0.
        title: plot title.
    """
    # The x axis holds the same integer days the curves are evaluated at, so
    # every plotted point sits on its own day.
    x = np.arange(days)
    y = np.zeros(days)
    for model in models.values():
        y += model.compute(x)
    plt.plot(x, y)
    plt.title(title)
    plt.show()
# Fit one COVIDCurve per region and show the raw series next to the fit.
models = {}
# Day index of every sample in stats[s]; identical for all regions, so it is
# built once.  Index i is plotted at x == i.
x = np.arange(duration)
for s in states:
    print("Fitting ", s)
    if np.max(stats[s]) < 1000:
        print("\tIgnoring due to lack of data")
        continue
    if np.min(stats[s]) < 0:
        print("\tIgnoring due to unreliable data")
        continue
    m = COVIDCurve(None)
    m.fit(stats[s], True)
    models[s] = m
    plt.plot(x, stats[s])
    plt.title(s + " original")
    plt.show()
    m.plot(s + " predicted", duration)

# Sum of all fitted regions, evaluated over 830 days from start_date.
plot_aggregate(models, 830, "India Daily Case Count")

from datetime import timedelta

# Calendar dates that lie 182, 301 and 550 days after start_date.
print(start_date + timedelta(days=182))
print(start_date + timedelta(days=301))
print(start_date + timedelta(days=550))
import requests
# Download the daily per-region counts published by covid19india.org.
# The timeout and explicit status check turn a stalled or failed request into
# a clear HTTP error instead of a hang or an obscure JSON decoding failure.
r = requests.get('https://api.covid19india.org/states_daily.json', timeout=30)
r.raise_for_status()
data = r.json()
print(len(data['states_daily']))

from datetime import datetime

date_format = "%Y-%m-%d"
start_date = datetime.strptime('2020-3-14', date_format)
end_date = datetime.strptime('2020-11-2', date_format)
# Number of days in the half-open window [start_date, end_date).
duration = (end_date - start_date).days

# Two-letter keys read from every row of the feed.
states = ['an', 'ap', 'ar', 'as', 'br', 'ch', 'ct', 'dd', 'dl', 'dn', 'ga', 'gj', 'hp', 'hr', 'jh', 'jk', 'ka', 'kl', 'la', 'ld', 'mh', 'ml', 'mn', 'mp', 'mz', 'nl', 'or', 'pb', 'py', 'rj', 'sk', 'tg', 'tn', 'tr', 'tt', 'un', 'up', 'ut', 'wb']

# stats[key][i] accumulates the 'Confirmed' count reported i days after
# start_date.
stats = {}
for s in states:
    stats[s] = [0] * duration

for row in data['states_daily']:
    if row['status'] != 'Confirmed':
        continue
    date = datetime.strptime(row['dateymd'], date_format)
    idx = (date - start_date).days
    # Check both ends of the window: a negative idx (a row dated before
    # start_date) would otherwise wrap around and be added to the tail.
    if not 0 <= idx < duration:
        continue
    for s in states:
        stats[s][idx] += int(row[s])
def fit_function(x, A, mu, sigma):
    """Return ``A * exp(-((x - mu) / sigma)**2 / 2)``.

    ``x`` may be a scalar or a NumPy array; ``A`` is the value at ``x == mu``.
    """
    offset = (x - mu) / sigma
    exponent = -0.5 * np.square(offset)
    return A * np.exp(exponent)
# Fit one Gaussian per region; models[s] holds the fitted [A, mu, sigma].
models = {}
import numpy as np
from scipy.optimize import curve_fit

# Day index of every sample in stats[s].  It is the same for all regions, so
# it is built once, and sample i sits at x == i so the fitted mu is in days.
x = np.arange(duration)
for s in states:
    # Starting guess: height and position of the largest sample, plus the
    # standard deviation of the values as a first width.
    estimates = [np.max(stats[s]), np.argmax(stats[s]), np.std(stats[s])]
    popt, _ = curve_fit(fit_function, x, stats[s], p0 = estimates)
    models[s] = popt
import matplotlib.pyplot as plt
def plot(A, mu, sigma, days = 365, title = ""):
    """Plot one Gaussian curve at ``days`` points spread over [0, days].

    Args:
        A, mu, sigma: parameters passed to ``fit_function``; each must be
            strictly positive.
        days: number of sample points and right end of the x axis.
        title: plot title.

    Raises:
        ValueError: if ``A``, ``mu`` or ``sigma`` is not strictly positive
            (NaN counts as invalid).
    """
    # Explicit raises instead of assert, so the checks survive ``python -O``.
    if not A > 0:
        raise ValueError("A must be positive")
    if not sigma > 0:
        raise ValueError("sigma must be positive")
    if not mu > 0:
        raise ValueError("mu must be positive")
    x = np.linspace(0, days, days)
    # fit_function works on arrays, so the whole grid goes in one call.
    plt.plot(x, fit_function(x, A, mu, sigma))
    plt.title(title)
    plt.show()
def plot_aggregate(days = 365, title = ""):
    """Plot the sum of the fitted Gaussians for days ``0 .. days-1``.

    Reads the module-level ``states`` and ``models``.  Regions whose fit has
    ``sigma < 1`` or ``A < 100`` are left out of the sum.

    Args:
        days: number of days to evaluate, starting at day 0.
        title: plot title.
    """
    # The x axis holds the same integer days the curves are evaluated at, so
    # every plotted point sits on its own day.
    x = np.arange(days)
    y = np.zeros(days)
    for s in states:
        A, mu, sigma = models[s]
        if sigma < 1 or A < 100:
            continue
        y += fit_function(x, A, mu, sigma)
    plt.plot(x, y)
    plt.title(title)
    plt.show()
# Show the fitted curve stored under 'ka', then the sum over all regions,
# both over the same 830-day horizon.
horizon = 830
plot(*models['ka'], horizon, "Karnataka Daily Case Count")
plot_aggregate(horizon, "India Daily Case Count")

from datetime import timedelta

# Calendar dates that lie 182, 301 and 550 days after start_date.
for day_offset in (182, 301, 550):
    print(start_date + timedelta(days=day_offset))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment