@jstults
Created November 20, 2020 19:20
Vector autoregression model fit to COVID-19 data
#
# filename: varcovid.py
# author: josh stults
# date: 20 Nov 2020
#
# fit a vector autoregressive (VAR) model to daily U.S. new cases and new deaths
#
# data from:
# https://covid.ourworldindata.org/data/owid-covid-data.csv
#
# inspiration from:
# https://www.theatlantic.com/science/archive/2020/11/coronavirus-death-rate-third-surge/617150/
# one of the researchers quoted in that article estimated the lag
# between cases and deaths as the shift that maximized the linear
# correlation between the two time series; pretty crude (fitting to
# averaged/smoothed data is poor practice), but interesting
#
# documentation of the vector autoregression model:
# https://www.statsmodels.org/dev/generated/statsmodels.tsa.vector_ar.var_model.VAR.html#statsmodels.tsa.vector_ar.var_model.VAR
#
# numerical tools
import numpy as np
# data management tools
import pandas as pd
# statistical models
import statsmodels.api as sm
from statsmodels.tsa.api import VAR
# visualization tools
from matplotlib import pyplot as plt
import seaborn as sns
# XXX codes = pd.read_csv("owid/owid-covid-codebook.csv") XXX #
data = pd.read_csv("owid/owid-covid-data.csv")
us = data["location"] == "United States"
usdata = data.loc[us, ["new_cases", "new_deaths"]]
usdata.index = pd.DatetimeIndex(data.loc[us, "date"])
recent = usdata.loc["2020-07-01":"2020-11-20"]  # limit model to more recent data
model = VAR(recent)
results = model.fit(maxlags=20, ic='aic')
aiclag = results.k_ar
print("kept lags up to %d" % aiclag)
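nsteps = 30  # predict 30 days ahead (not named np, so the numpy alias isn't shadowed)
# seed the forecast with the last k_ar rows of the fitted window, so it
# starts the day after that window ends even if the CSV has newer data
predict = results.forecast(recent.values[-aiclag:], nsteps)
predict = pd.DataFrame(predict, columns=["new_cases", "new_deaths"],
                       index=pd.date_range(recent.index[-1] + pd.Timedelta(days=1),
                                           periods=nsteps, freq="D"))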
# visualizations
plt.figure()
plt.plot(usdata.index.to_pydatetime(), usdata["new_cases"], 'o', label="New Cases")
plt.plot(predict.index.to_pydatetime(), predict["new_cases"], '-',
         label="VAR(%d) Forecast, fit 07-01 to 11-20" % aiclag)
plt.legend(loc=0)
plt.title("Daily New U.S. Covid-19 Cases")
# use a real timestamp: a bare string like '20200101' is not reliably
# converted to a position on a date axis
plt.text(pd.Timestamp("2020-01-01"), 0, "https://covid.ourworldindata.org/data/owid-covid-data.csv")
plt.savefig("us-new-cases.png", bbox_inches="tight")
plt.figure()
plt.plot(usdata.index.to_pydatetime(), usdata["new_deaths"], 'o', label="New Deaths")
plt.plot(predict.index.to_pydatetime(), predict["new_deaths"], '-',
         label="VAR(%d) Forecast, fit 07-01 to 11-20" % aiclag)
plt.legend(loc=0)
plt.title("Daily New U.S. Covid-19 Deaths")
# use a real timestamp: a bare string like '20200101' is not reliably
# converted to a position on a date axis
plt.text(pd.Timestamp("2020-01-01"), 0, "https://covid.ourworldindata.org/data/owid-covid-data.csv")
plt.savefig("us-new-deaths.png", bbox_inches="tight")
jstults commented Nov 20, 2020: [figure: us-new-cases]

jstults commented Nov 20, 2020: [figure: us-new-deaths]

jstults commented Nov 21, 2020: [figure: us-new-tests]
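The script's header comment mentions a researcher who chose the cases-to-deaths lag by maximizing linear correlation. A minimal sketch of that idea follows, assuming Pearson correlation, whole-day shifts, and a search window of up to 40 days (all assumptions; the comment does not spell out the exact method). The data are synthetic: `deaths` is built from `cases` with a hypothetical 21-day delay, so the search should recover that delay. `best_lag` is a hypothetical helper name.

```python
import numpy as np
import pandas as pd


def best_lag(cases, deaths, max_lag=40):
    """Return (lag, r): the shift of `cases` in days that maximizes the
    Pearson correlation with `deaths`, and that correlation."""
    # Series.corr ignores the NaN rows introduced by shift()
    corrs = {lag: deaths.corr(cases.shift(lag)) for lag in range(max_lag + 1)}
    lag = max(corrs, key=corrs.get)
    return lag, corrs[lag]


# Synthetic illustration only (hypothetical numbers, not OWID data):
# deaths follow cases with a built-in 21-day delay plus noise.
rng = np.random.default_rng(1)
idx = pd.date_range("2020-07-01", periods=200, freq="D")
t = np.arange(len(idx))
cases = pd.Series(50_000 + 30_000 * np.sin(2 * np.pi * t / 90)
                  + rng.normal(0, 2_000, len(t)), index=idx)
deaths = 0.02 * cases.shift(21) + rng.normal(0, 20, len(t))

lag, r = best_lag(cases, deaths)
print(f"best lag: {lag} days, r = {r:.3f}")
```

On smooth or smoothed series, neighbouring lags give nearly the same correlation, so the peak is broad and the chosen lag can be sensitive to noise; that is one plausible reason the header treats the approach as crude.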
