Last active
January 8, 2021 17:54
-
-
Save Caellwyn/4e6b4d680d9f698278f691922fa787e5 to your computer and use it in GitHub Desktop.
SARIMA_model_demonstration.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "SARIMA_model_demonstration.ipynb", | |
"private_outputs": true, | |
"provenance": [], | |
"collapsed_sections": [], | |
"authorship_tag": "ABX9TyOq06aUG9iHK/hZpVb9X31M", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/Caellwyn/4e6b4d680d9f698278f691922fa787e5/sarima_model_demonstration.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "Ebqh3gKgz6SL" | |
}, | |
"source": [ | |
"from statsmodels.tsa.statespace.sarimax import SARIMAX\r\n", | |
"import numpy as np\r\n", | |
"import pandas as pd\r\n", | |
"from plotly import graph_objects as go\r\n", | |
"\r\n", | |
"def get_covid_data(country, region=None, \r\n", | |
" stats = ['ConfirmedCases', 'ConfirmedDeaths']):\r\n", | |
" #downloads data for 'country' and 'region' and returns a pandas dataframe\r\n", | |
" \r\n", | |
" DATA_URL = 'https://raw.githubusercontent.com/OxCGRT/covid-policy-tracker/master/data/OxCGRT_latest.csv'\r\n", | |
" df = pd.read_csv(DATA_URL,\r\n", | |
" parse_dates=['Date'],\r\n", | |
" encoding=\"ISO-8859-1\",\r\n", | |
" dtype={\"RegionName\": str,\r\n", | |
" \"RegionCode\": str,\r\n", | |
" \"CountryName\": str,\r\n", | |
" \"CountryCode\": str},\r\n", | |
" usecols = ['Date','CountryName','RegionName',\r\n", | |
" 'ConfirmedCases','ConfirmedDeaths','Jurisdiction'],\r\n", | |
" index_col = 'Date',\r\n", | |
" error_bad_lines=False)\r\n", | |
" df = df[df['CountryName'] == country]\r\n", | |
"\r\n", | |
" if region:\r\n", | |
" df = df[df['RegionName'] == region]\r\n", | |
" else: \r\n", | |
" df = df[df['Jurisdiction'] == 'NAT_TOTAL']\r\n", | |
" df = df[stats]\r\n", | |
" return df\r\n", | |
"\r\n", | |
"def predict_COVID(df, forecast_length=30):\r\n", | |
" #returns a data frame of predictions for each columns in df\r\n", | |
" # out to 'forecast_length' time\r\n", | |
"\r\n", | |
" #transform the data by taking the cube root of each data point\r\n", | |
" cbrt_df = np.cbrt(df)\r\n", | |
" forecasts = pd.DataFrame()\r\n", | |
" #instantiate the model\r\n", | |
" for stat in df.columns:\r\n", | |
"\r\n", | |
" cbrt_df = np.cbrt(df[stat])\r\n", | |
" model = SARIMAX(cbrt_df, order = (0,2,0), seasonal_order = (3,2,1,7),\r\n", | |
" freq = 'D')\r\n", | |
" fit_model = model.fit(maxiter = 200, disp = False)\r\n", | |
" yhat = fit_model.forecast(forecast_length)**3\r\n", | |
" col = 'Forecasted_' + stat\r\n", | |
" forecasts[col] = yhat\r\n", | |
" return forecasts\r\n", | |
"\r\n", | |
"def plot_graph(df, country, region=None):\r\n", | |
" #plots 'stats' from 'df'. \r\n", | |
" #Uses 'country' and 'region' to make appropriate title\r\n", | |
" #returns a plotly figure\r\n", | |
"\r\n", | |
" if region:\r\n", | |
" title = f'Cumulative {\" and \".join(stats)} in {region}, {country}'\r\n", | |
" else:\r\n", | |
" title = f'Cumulative {\" and \".join(stats)} in {country}'\r\n", | |
"\r\n", | |
" fig = go.Figure(layout_title_text = title)\r\n", | |
" for stat in df.columns:\r\n", | |
" fig.add_trace(go.Scatter(x=df.index, y=df[stat], name=stat))\r\n", | |
" return fig\r\n", | |
"\r\n", | |
"country = 'United States'\r\n", | |
"region = None\r\n", | |
"forecast_length = 30\r\n", | |
"stats = ['ConfirmedDeaths', 'ConfirmedCases']\r\n", | |
"\r\n", | |
"df = get_covid_data(country=country, stats=stats)\r\n", | |
"forecasts = predict_COVID(df, forecast_length=forecast_length)\r\n", | |
"to_graph = pd.concat([df,forecasts], join='outer', axis=0)\r\n", | |
"plot_graph(to_graph, country=country, region=region)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment