Skip to content

Instantly share code, notes, and snippets.

@Caellwyn
Last active January 23, 2021 13:14
Show Gist options
  • Save Caellwyn/3517deb91ed2a8720266305aea4b6b79 to your computer and use it in GitHub Desktop.
Save Caellwyn/3517deb91ed2a8720266305aea4b6b79 to your computer and use it in GitHub Desktop.
fb_model.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "fb_model.ipynb",
"private_outputs": true,
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyNwxITfC7weYWxRkS7/4yOb",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/Caellwyn/3517deb91ed2a8720266305aea4b6b79/fb_model.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "NjGd8SYr9k1w"
},
"source": [
"import pandas as pd\r\n",
"# In Colab, install the dependency first with: %pip install fbprophet\r\n",
"from fbprophet import Prophet\r\n",
"import matplotlib.pyplot as plt\r\n",
"\r\n",
"# --- Configuration ---\r\n",
"# 'country' forecasts national totals; 'state' forecasts regional totals\r\n",
"# (regional data is available for some countries).\r\n",
"division = 'country'\r\n",
"# Matched against CountryName when division == 'country', RegionName when division == 'state'.\r\n",
"region = 'United States'\r\n",
"prediction = 'ConfirmedCases'  # ConfirmedDeaths is also available for forecasting.\r\n",
"\r\n",
"# --- Load the latest data from OxCGRT ---\r\n",
"DATA_URL = 'https://raw.githubusercontent.com/OxCGRT/covid-policy-tracker/master/data/OxCGRT_latest.csv'\r\n",
"full_df = pd.read_csv(DATA_URL,\r\n",
"                      usecols=['Date', 'CountryName', 'RegionName', 'Jurisdiction',\r\n",
"                               'ConfirmedCases', 'ConfirmedDeaths'],\r\n",
"                      parse_dates=['Date'],\r\n",
"                      encoding=\"ISO-8859-1\",\r\n",
"                      dtype={\"RegionName\": str,\r\n",
"                             \"CountryName\": str})\r\n",
"\r\n",
"# --- Select the series to forecast ---\r\n",
"# Each supported division maps to (Jurisdiction value, column holding the region name).\r\n",
"DIVISION_FILTERS = {\r\n",
"    'country': ('NAT_TOTAL', 'CountryName'),\r\n",
"    'state': ('STATE_TOTAL', 'RegionName'),\r\n",
"}\r\n",
"if division not in DIVISION_FILTERS:\r\n",
"    # Fail loudly; otherwise `df` would be undefined or silently left over from a previous run.\r\n",
"    raise ValueError(f'division must be one of {sorted(DIVISION_FILTERS)}, got {division!r}')\r\n",
"jurisdiction, name_column = DIVISION_FILTERS[division]\r\n",
"\r\n",
"# Keep only rows with an observed target. Filtering on missing values (presumably\r\n",
"# the most recent, not-yet-reported days) handles any number of incomplete trailing\r\n",
"# rows, whereas always dropping the last row could discard a real observation.\r\n",
"region_df = full_df[(full_df['Jurisdiction'] == jurisdiction)\r\n",
"                    & (full_df[name_column] == region)].dropna(subset=[prediction])\r\n",
"if region_df.empty:\r\n",
"    raise ValueError(f'no {prediction} data found for {division} {region!r}')\r\n",
"\r\n",
"# Prophet is fit on a frame with a 'ds' (date) column and a 'y' (target) column.\r\n",
"# No exogenous regressors are used, so only the date and the target are kept.\r\n",
"df = region_df[['Date', prediction]].rename(columns={'Date': 'ds', prediction: 'y'})"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "1s0h3Trm9xJD"
},
"source": [
"# --- Forecast settings ---\r\n",
"forecast_length = 30  # number of days to forecast past the last observed date\r\n",
"plot_start = '2020-12-01'  # earliest date shown in the plot; all history is still used for fitting\r\n",
"\r\n",
"# Instantiate and fit the model on the prepared history ('ds', 'y').\r\n",
"m = Prophet()\r\n",
"m.fit(df)\r\n",
"# History dates plus `forecast_length` additional days.\r\n",
"future = m.make_future_dataframe(periods=forecast_length)\r\n",
"forecast = m.predict(future)\r\n",
"\r\n",
"# Derive the plot text from the data instead of hard-coding it, so it stays correct\r\n",
"# when the data is refreshed, forecast_length changes, or deaths are forecast.\r\n",
"metric = prediction.replace('Confirmed', '').lower()  # 'cases' or 'deaths'\r\n",
"forecast_start = df['ds'].max() + pd.Timedelta(days=1)\r\n",
"forecast_end = future['ds'].max()\r\n",
"date_range = '{0.month}-{0.day}-{0.year} to {1.month}-{1.day}-{1.year}'.format(forecast_start, forecast_end)\r\n",
"\r\n",
"# Left join keeps the future dates, whose observed value 'y' is NaN.\r\n",
"to_plot = forecast[forecast['ds'] > plot_start].merge(df, on='ds', how='left')\r\n",
"\r\n",
"fig, ax = plt.subplots(figsize=(10, 7))\r\n",
"ax.plot(to_plot['ds'], to_plot['yhat'], label=f'Forecasted {metric.capitalize()}')\r\n",
"ax.plot(to_plot['ds'], to_plot['y'], label=f'True {metric.capitalize()}')\r\n",
"# yhat_lower/yhat_upper bound Prophet's uncertainty interval, whose width is m.interval_width.\r\n",
"ax.fill_between(to_plot['ds'], to_plot['yhat_upper'], to_plot['yhat_lower'],\r\n",
"                alpha=.2, label=f'{m.interval_width:.0%} uncertainty interval')\r\n",
"ax.set_title(f'Facebook Prophet Forecasted COVID-19 {metric}, {date_range}')\r\n",
"ax.set_xlabel('Date')\r\n",
"ax.set_ylabel(prediction)\r\n",
"ax.legend()\r\n",
"fig.savefig('prophet_forecast.png')\r\n",
"plt.show()\r\n",
"print('\\n The \"forecast\" DataFrame \\n')\r\n",
"forecast"
],
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment