Skip to content

Instantly share code, notes, and snippets.

@nielskou
Created March 31, 2020 13:05
Show Gist options
  • Save nielskou/1bf7da7f37259cdf459fa1ea07b15893 to your computer and use it in GitHub Desktop.
Save nielskou/1bf7da7f37259cdf459fa1ea07b15893 to your computer and use it in GitHub Desktop.
import numpy as np
import matplotlib.pylab as pl
from scipy.optimize import curve_fit
import requests
import os.path
data_file = 'key-countries-pivoted.csv'
# Download the dataset only when no local copy exists; later runs reuse the file.
if not os.path.exists(data_file):
    url = "https://raw.githubusercontent.com/datasets/covid-19/master/data/key-countries-pivoted.csv"
    # A timeout keeps the script from hanging forever on a stalled connection.
    myfile = requests.get(url, timeout=30)
    # Fail loudly on an HTTP error (e.g. 404) instead of saving the error page
    # as the CSV: once the file exists it is never re-downloaded, so a bad
    # copy would break every later run.
    myfile.raise_for_status()
    with open(data_file, 'wb') as file:
        file.write(myfile.content)
def func(x, a0, a1, a2):
    """Evaluate the shifted, scaled tanh growth model ``a0 * (tanh((x + a1) * a2) + 1)``.

    For positive ``a0`` and ``a2`` the curve climbs from 0 towards the
    plateau ``2 * a0``.  It passes through ``a0`` at ``x = -a1``, where it is
    steepest, and ``a2`` controls how fast it rises.  ``np.tanh`` is applied
    element-wise, so ``x`` may be a scalar or a NumPy array.
    """
    shifted = (x + a1) * a2
    return a0 * (np.tanh(shifted) + 1)
def fit(func, x, y, popt=None, bounds=None):
    """Fit ``func`` to the points ``(x, y)`` with bounded least squares.

    Parameters
    ----------
    func : callable
        Model ``func(x, a0, a1, a2)`` handed to ``scipy.optimize.curve_fit``.
    x, y : array_like
        Data points to fit.
    popt : sequence of 3 floats, optional
        Initial guess for ``(a0, a1, a2)``; defaults to ``[50000, -30, 0.1]``.
        It must lie inside ``bounds``.
    bounds : (lower, upper), optional
        Parameter bounds; defaults to ``([0, -60, 0], [1e6, -10, 1.])``.

    Returns
    -------
    (popt, pcov)
        Best-fit parameters and their covariance as returned by ``curve_fit``.
        If the fit fails, 'Fit failed!' is printed and ``[[0, 0, 0], None]``
        is returned, so callers must check ``pcov is None`` before using it.
    """
    # Build fresh defaults on every call rather than sharing mutable
    # default-argument objects between calls.
    if popt is None:
        popt = [50000, -30, 0.1]
    if bounds is None:
        bounds = ([0, -60, 0], [1e6, -10, 1.])
    try:
        popt1, pcov = curve_fit(func, x, y, p0=popt, bounds=bounds)
    except (RuntimeError, ValueError, TypeError, np.linalg.LinAlgError):
        # RuntimeError: no convergence within the evaluation budget.
        # ValueError: NaN/inf or empty data, or an initial guess outside bounds.
        # TypeError: too few points for the unbounded 'lm' method.
        # LinAlgError: the covariance decomposition failed.
        # Anything else (e.g. KeyboardInterrupt, a bug in func) propagates.
        print('Fit failed!')
        return [[0, 0, 0], None]
    return popt1, pcov
# Parse data_file once: the header row supplies the column names and each
# column is read with genfromtxt's default float dtype.
data = np.genfromtxt(data_file, delimiter=',', names=True)
# Take the names genfromtxt actually assigned to the fields.  It sanitises
# header text (e.g. spaces become underscores), so indexing ``data`` with the
# raw header strings could fail for such columns.
names = list(data.dtype.names)
def _confidence_label(delta_hd):
    """Return the confidence label printed in a panel for a delta-hump-day value.

    Values below 2 give 'low', values from 2 to 10 inclusive give 'medium',
    and everything else (larger values, or NaN) gives 'high'.
    """
    if delta_hd < 2:
        return 'low'
    if 2 <= delta_hd <= 10:
        return 'medium'
    return 'high'


def _can_sample(pcov):
    """Return True if ``pcov`` is a finite matrix usable for random sampling.

    ``fit`` returns ``None`` for the covariance when the fit fails.  Per the
    scipy docs, ``curve_fit`` fills the covariance with inf when it cannot be
    estimated.  Neither case can be passed to ``np.random.multivariate_normal``.
    """
    return pcov is not None and bool(np.all(np.isfinite(pcov)))


count = 0
# Date stamp printed in every panel.  It is a fixed string, not read from the
# data, so it must be updated by hand when the CSV is refreshed.
date = 'Mar 31, 2020'
n_cols = 4
# Keep at least the 2 x 4 grid, and add whole rows when there are more than
# eight series so pl.subplot never gets an out-of-range panel index.
n_rows = max(2, int(np.ceil((len(names) - 1) / n_cols)))
for name in names[1:]:
    y = data[name]
    # Keep only days with at least 100 confirmed cases.  The boolean mask also
    # drops NaN (missing) entries, which would otherwise make the fit fail.
    y = y[y >= 100]
    len_data = len(y)
    x = np.arange(len_data)
    popt1, pcov = fit(func, x, y)
    # Days from the fitted midpoint (x = -a1) to the last data point.
    delta_hd = np.round(len_data - 1 + popt1[1], 1)
    print(name, 'delta hump day: ', delta_hd)
    # Shift the x axis so the fitted midpoint is drawn at 0.
    x_off = popt1[1]
    x_fit = np.linspace(0, 120, 100)
    pl.subplot(n_rows, n_cols, count + 1)
    # Reference lines: the midpoint (x = 0) and the fitted plateau 2 * a0.
    pl.plot([0, 0], [0, 1000000], 'r--', alpha=0.5)
    pl.plot([-30, 30], [2 * popt1[0], 2 * popt1[0]], 'r--', alpha=0.5)
    pl.plot(x_fit + x_off, func(x_fit, *popt1), label='Fit')
    # Draw 30 curves with parameters sampled from the fit covariance, but only
    # when a usable covariance exists (a failed fit returns None).
    if _can_sample(pcov):
        for _ in range(30):
            sample = np.random.multivariate_normal(popt1, pcov)
            pl.plot(x_fit + x_off, func(x_fit, *sample), label=None, color='gray', alpha=0.1)
    pl.plot(x + x_off, y, 'x', color='r', label='Data')
    conv = _confidence_label(delta_hd)
    textstr = '\n'.join((name, r'Delta hump day: ' + str(delta_hd), 'Confidence: ' + conv, str(date)))
    props = dict(boxstyle='round', facecolor='w', alpha=0.8)
    pl.text(-25, 162000, textstr, fontsize=10, bbox=props)
    # Only the first panel of each row carries y tick labels and the y label.
    if count % n_cols != 0:
        pl.yticks([])
    else:
        pl.ylabel('Number of confirmed cases')
    pl.xlim([-27, 30])
    pl.ylim([0, 300000])
    pl.xlabel('Delta hump day (d)')
    pl.legend(loc=2)
    count += 1
pl.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment