Skip to content

Instantly share code, notes, and snippets.

@Caellwyn
Last active January 8, 2021 17:54
Show Gist options
  • Save Caellwyn/4e6b4d680d9f698278f691922fa787e5 to your computer and use it in GitHub Desktop.
Save Caellwyn/4e6b4d680d9f698278f691922fa787e5 to your computer and use it in GitHub Desktop.
SARIMA_model_demonstration.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "SARIMA_model_demonstration.ipynb",
"private_outputs": true,
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyOq06aUG9iHK/hZpVb9X31M",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/Caellwyn/4e6b4d680d9f698278f691922fa787e5/sarima_model_demonstration.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "Ebqh3gKgz6SL"
},
"source": [
"from statsmodels.tsa.statespace.sarimax import SARIMAX\r\n",
"import numpy as np\r\n",
"import pandas as pd\r\n",
"from plotly import graph_objects as go\r\n",
"\r\n",
"def get_covid_data(country, region=None,\r\n",
"                   stats=None):\r\n",
"    \"\"\"Download OxCGRT data for one country (optionally one region).\r\n",
"\r\n",
"    Parameters\r\n",
"    ----------\r\n",
"    country : str\r\n",
"        Value matched against the 'CountryName' column.\r\n",
"    region : str, optional\r\n",
"        Value matched against the 'RegionName' column. When falsy, only\r\n",
"        rows whose 'Jurisdiction' is 'NAT_TOTAL' are kept.\r\n",
"    stats : list of str, optional\r\n",
"        Columns to return; defaults to ['ConfirmedCases', 'ConfirmedDeaths'].\r\n",
"        Only columns loaded via `usecols` below can be selected.\r\n",
"\r\n",
"    Returns\r\n",
"    -------\r\n",
"    pandas.DataFrame indexed by 'Date', holding only the `stats` columns.\r\n",
"    \"\"\"\r\n",
"    if stats is None:\r\n",
"        stats = ['ConfirmedCases', 'ConfirmedDeaths']\r\n",
"\r\n",
"    DATA_URL = 'https://raw.githubusercontent.com/OxCGRT/covid-policy-tracker/master/data/OxCGRT_latest.csv'\r\n",
"    df = pd.read_csv(DATA_URL,\r\n",
"                     parse_dates=['Date'],\r\n",
"                     encoding=\"ISO-8859-1\",\r\n",
"                     dtype={\"RegionName\": str,\r\n",
"                            \"RegionCode\": str,\r\n",
"                            \"CountryName\": str,\r\n",
"                            \"CountryCode\": str},\r\n",
"                     usecols=['Date', 'CountryName', 'RegionName',\r\n",
"                              'ConfirmedCases', 'ConfirmedDeaths', 'Jurisdiction'],\r\n",
"                     index_col='Date',\r\n",
"                     # skip malformed rows; `error_bad_lines=False` was deprecated\r\n",
"                     # in pandas 1.3 and removed in 2.0 (TypeError on current pandas)\r\n",
"                     on_bad_lines='skip')\r\n",
"    df = df[df['CountryName'] == country]\r\n",
"\r\n",
"    if region:\r\n",
"        df = df[df['RegionName'] == region]\r\n",
"    else:\r\n",
"        # presumably the country-level aggregate rows\r\n",
"        df = df[df['Jurisdiction'] == 'NAT_TOTAL']\r\n",
"    # list() so any sequence of names selects columns (a tuple would\r\n",
"    # otherwise be looked up as a single key)\r\n",
"    return df[list(stats)]\r\n",
"\r\n",
"def predict_COVID(df, forecast_length=30, order=(0, 2, 0),\r\n",
"                  seasonal_order=(3, 2, 1, 7), maxiter=200):\r\n",
"    \"\"\"Forecast every column of `df` with a SARIMAX model fitted on cube roots.\r\n",
"\r\n",
"    Each column is fitted independently on its cube root, forecast\r\n",
"    `forecast_length` daily steps ahead, and cubed back to the original scale.\r\n",
"\r\n",
"    Parameters\r\n",
"    ----------\r\n",
"    df : pandas.DataFrame\r\n",
"        Date-indexed frame; every column is modelled separately.\r\n",
"    forecast_length : int\r\n",
"        Number of steps to forecast past the end of `df`.\r\n",
"    order, seasonal_order : tuple\r\n",
"        SARIMAX (p, d, q) and (P, D, Q, s) orders; the defaults are the\r\n",
"        (0,2,0)x(3,2,1,7) specification this notebook was built around.\r\n",
"    maxiter : int\r\n",
"        Maximum optimizer iterations passed to `fit`.\r\n",
"\r\n",
"    Returns\r\n",
"    -------\r\n",
"    pandas.DataFrame with one 'Forecasted_<column>' column per input column.\r\n",
"    \"\"\"\r\n",
"    forecasts = {}\r\n",
"    for stat in df.columns:\r\n",
"        # cube-root transform before fitting; inverted by **3 on the forecast\r\n",
"        cbrt_series = np.cbrt(df[stat])\r\n",
"        model = SARIMAX(cbrt_series, order=order,\r\n",
"                        seasonal_order=seasonal_order, freq='D')\r\n",
"        fit_model = model.fit(maxiter=maxiter, disp=False)\r\n",
"        forecasts['Forecasted_' + stat] = fit_model.forecast(forecast_length) ** 3\r\n",
"    return pd.DataFrame(forecasts)\r\n",
"\r\n",
"def plot_graph(df, country, region=None, stats=None):\r\n",
"    \"\"\"Plot every column of `df` as a line trace and return the plotly figure.\r\n",
"\r\n",
"    Parameters\r\n",
"    ----------\r\n",
"    df : pandas.DataFrame\r\n",
"        Date-indexed frame; each column becomes one trace.\r\n",
"    country, region : str\r\n",
"        Used only to build the figure title.\r\n",
"    stats : list of str, optional\r\n",
"        Statistic names shown in the title. Defaults to the columns of `df`\r\n",
"        whose names do not start with 'Forecasted_', instead of reading a\r\n",
"        module-level `stats` variable.\r\n",
"\r\n",
"    Returns\r\n",
"    -------\r\n",
"    plotly.graph_objects.Figure\r\n",
"    \"\"\"\r\n",
"    if stats is None:\r\n",
"        stats = [col for col in df.columns if not str(col).startswith('Forecasted_')]\r\n",
"    stat_names = ' and '.join(stats)\r\n",
"    if region:\r\n",
"        title = f'Cumulative {stat_names} in {region}, {country}'\r\n",
"    else:\r\n",
"        title = f'Cumulative {stat_names} in {country}'\r\n",
"\r\n",
"    fig = go.Figure(layout_title_text=title)\r\n",
"    for stat in df.columns:\r\n",
"        fig.add_trace(go.Scatter(x=df.index, y=df[stat], name=stat))\r\n",
"    return fig\r\n",
"\r\n",
"# --- configuration -------------------------------------------------------\r\n",
"country = 'United States'\r\n",
"region = None  # optional 'RegionName' value; None selects the 'NAT_TOTAL' rows\r\n",
"forecast_length = 30\r\n",
"stats = ['ConfirmedDeaths', 'ConfirmedCases']\r\n",
"\r\n",
"# pass `region` through so the data matches the plot title when it is set\r\n",
"df = get_covid_data(country=country, region=region, stats=stats)\r\n",
"forecasts = predict_COVID(df, forecast_length=forecast_length)\r\n",
"# stack history and forecast rows; columns not shared are NaN-filled\r\n",
"to_graph = pd.concat([df, forecasts], join='outer', axis=0)\r\n",
"plot_graph(to_graph, country=country, region=region)"
],
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment