# GitHub gist jstults/22d89987e24888e9c06f8ae9e7584fef, created November 20, 2020 19:20
# Vector autoregression model fit to COVID-19 data
#
# filename: varcovid.py
# author: josh stults
# date: 20 Nov 2020
#
# fit some vector autoregressive (VAR) models to daily U.S. deaths and cases
#
# data from:
# https://covid.ourworldindata.org/data/owid-covid-data.csv
#
# inspiration from:
# https://www.theatlantic.com/science/archive/2020/11/coronavirus-death-rate-third-surge/617150/
# one of the researchers quoted in that article found the lag between
# deaths and cases that maximized the linear correlation between the two
# time series; pretty crude (it is poor practice to fit to averaged/smoothed
# data), but interesting
#
# documentation of the vector autoregression model:
# https://www.statsmodels.org/dev/generated/statsmodels.tsa.vector_ar.var_model.VAR.html#statsmodels.tsa.vector_ar.var_model.VAR
#
# numerical tools
import numpy as np
# data management tools
import pandas as pd
# statistical models
import statsmodels.api as sm
from statsmodels.tsa.api import VAR
# visualization tools
from matplotlib import pyplot as plt
import seaborn as sns

# --- configuration -------------------------------------------------------
DATA_URL = "https://covid.ourworldindata.org/data/owid-covid-data.csv"
DATA_PATH = "owid/owid-covid-data.csv"  # local copy of DATA_URL
COLUMNS = ["new_cases", "new_deaths"]
FIT_START = pd.Timestamp("2020-07-01")  # limit model to more recent data
FIT_END = pd.Timestamp("2020-11-20")
MAX_LAGS = 20  # upper bound for the AIC lag-order search
STEPS_AHEAD = 30  # predict 30 days ahead (was `np`, which shadowed numpy)
# x position of the data-source note.  A real datetime is used because the
# x axis holds dates; a bare string like '20200101' is not reliably
# convertible by matplotlib's date units.
NOTE_X = pd.Timestamp("2020-01-01").to_pydatetime()


def _plot_forecast(observed, forecast, column, observed_label,
                   forecast_label, title, filename):
    """Plot observed values of *column* plus its forecast and save the figure.

    *observed* and *forecast* are DataFrames with a DatetimeIndex that both
    contain *column*.  Observations are drawn as markers, the forecast as a
    line; the figure is written to *filename*.
    """
    plt.figure()
    plt.plot(observed.index.to_pydatetime(), observed[column], 'o',
             label=observed_label)
    plt.plot(forecast.index.to_pydatetime(), forecast[column], '-',
             label=forecast_label)
    plt.legend(loc=0)
    plt.title(title)
    plt.text(NOTE_X, 0, DATA_URL)
    plt.savefig(filename, bbox_inches="tight")


# --- data ----------------------------------------------------------------
data = pd.read_csv(DATA_PATH)
us_rows = data.loc[data["location"] == "United States"]  # filter once
usdata = us_rows[COLUMNS]
usdata.index = pd.DatetimeIndex(us_rows["date"])

# Fit window.  Rows with any missing value are dropped so NaNs cannot
# propagate into the VAR coefficients and forecast.
recent = usdata.loc[FIT_START:FIT_END].dropna()

# --- model ---------------------------------------------------------------
model = VAR(recent)
results = model.fit(maxlags=MAX_LAGS, ic='aic')
aiclag = results.k_ar
print("kept lags up to %d" % aiclag)

# Seed the forecast with the last `aiclag` observations of the *fitted*
# window (the full series may extend past FIT_END), and start the forecast
# index the day after that last observation instead of a hard-coded date.
# The explicit start index avoids the `x[-0:]` pitfall, which returns every
# row when aiclag == 0.
seed = recent.values[len(recent) - aiclag:]
predict = results.forecast(seed, STEPS_AHEAD)
predict = pd.DataFrame(
    predict,
    columns=recent.columns,
    index=pd.date_range(recent.index[-1] + pd.Timedelta(days=1),
                        periods=STEPS_AHEAD, freq='D'),
)

# --- visualizations ------------------------------------------------------
# Renders as e.g. "VAR(7) Forecast, fit 07-01 to 11-20".
forecast_label = "VAR(%d) Forecast, fit %s to %s" % (
    aiclag, FIT_START.strftime("%m-%d"), FIT_END.strftime("%m-%d"))
_plot_forecast(usdata, predict, "new_cases", "New Cases", forecast_label,
               "Daily New U.S. Covid-19 Cases", "us-new-cases.png")
_plot_forecast(usdata, predict, "new_deaths", "New Deaths", forecast_label,
               "Daily New U.S. Covid-19 Deaths", "us-new-deaths.png")
# Gist page footer: comment by the author (jstults), Nov 20, 2020; the comment text was not captured.