Skip to content

Instantly share code, notes, and snippets.

@ghamerly
Last active January 12, 2021 12:36
Show Gist options
  • Save ghamerly/723c6bad926d6c1523c094a3b6a3eb6a to your computer and use it in GitHub Desktop.
Save ghamerly/723c6bad926d6c1523c094a3b6a3eb6a to your computer and use it in GitHub Desktop.
*.csv
*.png
*.pdf
*.svg
#!/usr/bin/env python3
'''
This script uses data from https://github.com/CSSEGISandData/COVID-19/ and:
- grabs it (using the requests package)
- filters it (keeps the region named on the command line and only the most recent days)
- fits an exponential growth model to the data,
- plots the data and the growth curve,
- makes predictions of the numbers using the growth curve.
It can plot each type of data found under the subdirectory
csse_covid_19_data/csse_covid_19_time_series (i.e. per-country, or
per-USA-county).
'''
import argparse
import csv
import hashlib
import os
import sys
import time
import matplotlib.pyplot
import requests
import numpy.linalg
class Constants:  # pylint: disable=too-few-public-methods
    '''Static lookup tables (URLs, file names, CSV column names).'''

    # CSV column names that identify a region in the per-country files.
    _GLOBAL_HEADERS = dict(
        major='Country/Region',
        major_default='US',
        minor='Province/State',
    )

    # CSV column names that identify a region in the USA files.
    _USA_HEADERS = dict(
        major='Province_State',
        major_default=None,
        minor='Admin2',
    )

    # Directory that holds every time-series file listed in DATA_OPTIONS.
    DATA_URL_BASE = (
        'https://raw.githubusercontent.com/'
        'CSSEGISandData/COVID-19/master/'
        'csse_covid_19_data/csse_covid_19_time_series/'
    )

    # Series name -> file name, region column names and display description.
    DATA_OPTIONS = {
        'deaths': dict(
            file='time_series_covid19_deaths_global.csv',
            headers=_GLOBAL_HEADERS,
            description='deaths',
        ),
        'infections': dict(
            file='time_series_covid19_confirmed_global.csv',
            headers=_GLOBAL_HEADERS,
            description='infections',
        ),
        'recovered': dict(
            file='time_series_covid19_recovered_global.csv',
            headers=_GLOBAL_HEADERS,
            description='recovered',
        ),
        'usa_infections': dict(
            file='time_series_covid19_confirmed_US.csv',
            headers=_USA_HEADERS,
            description='infections',
        ),
        'usa_deaths': dict(
            file='time_series_covid19_deaths_US.csv',
            headers=_USA_HEADERS,
            description='deaths',
        ),
    }
def parse_args():
    '''Parse the command line arguments (and create extra fields "url",
    "headers" and "description").

    Returns the argparse namespace. Besides the declared options it carries:
    - url: DATA_URL_BASE joined with the file name of the selected series,
    - headers: the region column names for that file,
    - description: the display name of the selected series,
    - major: the given major region, or the series default when omitted.

    Invalid combinations are rejected through parser.error, which prints a
    usage message and exits.
    '''
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--major', default='', help='Major region (country or USA state)')
    parser.add_argument('--minor', default='', help='Minor region (state/province or USA county)')
    parser.add_argument('--days', type=int, default=20, help='How many days to look at')
    parser.add_argument('--test', type=int, default=0,
                        help='How many recent examples to use for testing (and not train)')
    parser.add_argument('--which', choices=list(Constants.DATA_OPTIONS), default='infections',
                        help='Which series to plot')
    parser.add_argument('--linear_y', action='store_true', default=False,
                        help='Use linear scale for y-axis')
    parser.add_argument('--save', help='Filename of saved plot (if not specified, do not save)')
    args = parser.parse_args()

    # The model is fitted on (days - test) points, so that count must be
    # positive; reject bad values here instead of failing later in numpy.
    if args.days < 1:
        parser.error('--days must be at least 1')
    if not 0 <= args.test < args.days:
        parser.error('--test must be at least 0 and smaller than --days')

    option = Constants.DATA_OPTIONS[args.which]
    args.url = Constants.DATA_URL_BASE + option['file']
    args.headers = option['headers']
    args.description = option['description']

    args.major = args.major or args.headers['major_default']
    if args.major is None:
        # The USA series have no default major region; without this check the
        # None value would only fail later, when the rows are compared.
        parser.error('--major is required when --which is {}'.format(args.which))
    return args
def get_data(args):
    '''Get the data of interest.

    Reads the CSV text for args.url, either from a cache file in the current
    directory (when it is less than ten minutes old) or by downloading it
    with requests, in which case the cache file is (re)written.

    Returns the first row (a dict keyed by CSV column name) whose major and
    minor region columns, named by args.headers, match args.major and
    args.minor case-insensitively, or None when no row matches.

    Raises requests.HTTPError when the download returns an error status.
    '''
    cached_data_file = 'cache_' + hashlib.md5(args.url.encode('utf-8')).hexdigest() + '.csv'

    raw_csv = None
    try:
        age = time.time() - os.stat(cached_data_file).st_mtime
        if age < 10 * 60:  # cache data for 10 minutes
            print('reading cached data from', cached_data_file)
            with open(cached_data_file, encoding='utf-8') as cache:
                raw_csv = cache.read()
        else:
            print('cached data is too old')
    except (OSError, UnicodeError) as e:
        # A missing or unreadable cache is not fatal: fall through to download.
        print('exception while getting cached data', e)

    if not raw_csv:
        print(f'retrieving data from {args.url}')
        response = requests.get(args.url, timeout=30)
        # Do not cache (or try to parse) an HTTP error page.
        response.raise_for_status()
        raw_csv = response.text
        with open(cached_data_file, 'w', encoding='utf-8') as cache:
            cache.write(raw_csv)

    major_key = args.headers['major']
    minor_key = args.headers['minor']
    # Treat a missing region (None) like the empty string so the comparison
    # below cannot fail on .lower().
    wanted_major = (args.major or '').lower()
    wanted_minor = (args.minor or '').lower()

    for row in csv.DictReader(raw_csv.split('\n')):
        row_major = (row.get(major_key) or '').lower()
        row_minor = (row.get(minor_key) or '').lower()
        if row_major == wanted_major and row_minor == wanted_minor:
            return row
    return None
def fit_exponential_model(args, log_train):  # pylint: disable=unused-argument
    '''Fit a linear model to the log-data, i.e. log(y) ~ theta[0] + theta[1] * x

    The x values are 0, 1, ..., len(log_train) - 1, one per training point.
    The design matrix is sized from log_train itself, so it always matches
    the data. args is accepted for interface compatibility and is not read.

    Returns theta, the least-squares [intercept, slope], computed through the
    normal equations with a pseudo-inverse.

    Raises ValueError when log_train is empty.
    '''
    n_points = len(log_train)
    if n_points == 0:
        raise ValueError('cannot fit a model without training data')

    # data matrix: a column of ones (intercept) and a column of day indices
    x_matrix = numpy.array([[1, i] for i in range(n_points)])
    x_transposed = numpy.transpose(x_matrix)
    x_t_x = numpy.matmul(x_transposed, x_matrix)
    xtx_inv_xt = numpy.matmul(numpy.linalg.pinv(x_t_x), x_transposed)
    theta = numpy.matmul(xtx_inv_xt, numpy.transpose(log_train))
    return theta
def main():
    '''Entry point: read the options, fetch the series, fit the model,
    print the predictions and draw the plot.'''
    args = parse_args()

    # look up the row for the requested region
    counts = get_data(args)
    if counts is None:
        print("Could not find the data you were looking for... bailing.")
        return

    def safe_log(value):
        '''Natural log of value; a zero count is mapped to -1.'''
        return numpy.log(value) if value else -1

    # the last args.days columns of the row, split into train and test parts
    recent = [int(cell) for cell in list(counts.values())[-args.days:]]
    split_at = len(recent) - args.test
    train, test = recent[:split_at], recent[split_at:]

    log_train = [safe_log(count) for count in train]  # this is our "y"
    log_test = [safe_log(count) for count in test]
    print('train', len(log_train), log_train)
    print('test', len(log_test), log_test)

    theta = fit_exponential_model(args, log_train)
    print('estimated parameters', theta)

    # compute and print the % daily growth and predictions
    growth_str = make_predictions(args, theta, recent)

    if args.linear_y:
        y_data = train + test
    else:
        y_data = log_train + log_test
    plot_data(args, counts, theta, y_data, growth_str)
def make_predictions(args, theta, recent):
    '''Print the modelled daily growth rate and the projected counts for day
    offsets 0 through 29, starting from the most recent count; return the
    growth percentage formatted with one decimal place.'''
    latest = recent[-1]
    # theta[1] is the slope of the log-linear fit, so exp() gives the daily factor
    daily_factor = numpy.exp(theta[1])
    growth_str = f'{(daily_factor - 1) * 100:0.1f}'

    print(f'the most recent count of {args.description} is {latest}')
    print(f'the {args.description} are growing {growth_str}% each day')

    for offset in range(30):
        projected = int((daily_factor ** offset) * latest)
        print(f'in {offset} days, that means {projected} {args.description}')

    return growth_str
def plot_data(args, counts, theta, y_data, growth_str):
    '''Plot the data and model using matplotlib.

    args: parsed options; reads linear_y, test, description, headers, save.
    counts: the CSV row for the region; its last key is used as the
        "as of" date and its region columns provide the name in the title.
    theta: fitted [intercept, slope] of the log-linear model.
    y_data: the values to plot, one per day (raw counts when args.linear_y
        is set, otherwise their logs).
    growth_str: daily growth percentage shown in the title.

    The figure is saved to args.save when that is set, then shown.
    '''
    plt = matplotlib.pyplot

    # Size everything from the data actually plotted so that the x and y
    # sequences always have the same length.
    n_points = len(y_data)
    x_values = list(range(n_points))

    y_hat = [theta[0] + theta[1] * x for x in x_values]
    if args.linear_y:
        y_hat = numpy.exp(y_hat)
    plt.plot(x_values, y_data, 'x-')
    plt.plot(x_values, y_hat)

    # make the labels + title + legend
    # every second position, labelled with how many days back it lies
    plt.xticks(x_values[::2], x_values[::-2])
    plt.xlabel('days ago')

    if not args.linear_y:
        # y values are natural logs: put ticks at powers of ten near the
        # visible range and label them with the un-logged value
        y_min, y_max = plt.ylim()
        log_10 = numpy.log(10)
        decades = [i for i in range(10)
                   if y_min <= (i + 1) * log_10 and (i - 1) * log_10 <= y_max]
        plt.yticks([i * log_10 for i in decades], [10 ** i for i in decades])

    # build the label without stray double or trailing spaces
    if args.linear_y:
        plt.ylabel(args.description)
    else:
        plt.ylabel(args.description + ' (log scale)')

    most_recent_date = list(counts)[-1]
    minor_name = ''
    if counts[args.headers['minor']]:
        minor_name = ' (' + counts[args.headers['minor']] + ')'
    name = '{}{}'.format(counts[args.headers['major']], minor_name)
    plt.title('COVID-19 {} in {}: {}% daily growth (as of {})'.format(
        args.description, name, growth_str, most_recent_date))

    legend = ['reported cases', 'model fit']
    train_end = n_points - args.test - 1
    plt.axvspan(0, train_end, facecolor='y', alpha=0.3)
    legend.append('training region')
    if args.test:
        plt.axvspan(train_end, n_points - 1, facecolor='g', alpha=0.3)
        legend.append('prediction region')
    plt.legend(legend)

    # make sure all the labels fit on the plot
    plt.tight_layout()

    if args.save:
        print('saving to', args.save)
        plt.savefig(args.save, dpi=300)
    plt.show()
# Run the tool only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment