Skip to content

Instantly share code, notes, and snippets.

@joelanders
Created March 11, 2021 10:52
Show Gist options
  • Save joelanders/e3a9f4f57fb9ebacc859d7f73595d46a to your computer and use it in GitHub Desktop.
Save joelanders/e3a9f4f57fb9ebacc859d7f73595d46a to your computer and use it in GitHub Desktop.
from datetime import datetime
from datetime import timedelta
import pandas as pd
import plotly.express as px
import plotly.graph_objs as go
import statsmodels.api as sm
import math
# Load the vaccination data. Judging by the columns used further down, the
# file has one row per location and date.
# (An alternative input that was tried earlier: 'owid-covid-data.csv'.)
df = pd.read_csv('vaccinations.csv')

# Keep only the plotted window, 2021-01-01 through 2021-03-12 inclusive.
# The `date` column holds bare "YYYY-MM-DD" strings (the format parsed for
# `day_of_year` below), so the bounds must be bare dates as well: a bound
# with a time suffix such as "2021-01-01 00:00:00.000" sorts *after* the
# string "2021-01-01" and would silently drop the first day of the window.
# .copy() makes the result an independent frame so that adding a column
# below does not write into a slice of the original (SettingWithCopyWarning).
df = df[(df.date >= "2021-01-01") & (df.date <= "2021-03-12")].copy()

# Whole days elapsed since 2021-01-01, i.e. 1 January is day 0. This is the
# regressor used for the per-country linear fit.
df['day_of_year'] = (
    pd.to_datetime(df['date'], format="%Y-%m-%d") - datetime(2021, 1, 1)
).dt.days

# XXX sloppy way to set the x scale: the axis (and the extrapolated trend
# lines) simply run for a fixed number of days from 2021-01-01.
x_days = 190
x_max_date = datetime(2021, 1, 1) + timedelta(days=x_days)

fig = go.Figure(
    layout_yaxis_range=[0, 100],
    layout_xaxis_range=["2021-01-01 00:00:00.0000", x_max_date],
    layout_yaxis_title_text="cumulative doses per hundred people",
    layout_title_text="cumulative doses per hundred people",
    layout_title_x=0.5,  # centre the title horizontally
)
fig.update_yaxes(nticks=20)
# For each country: plot the observed cumulative doses per hundred people as
# markers, fit a straight line to the most recent observations, and draw that
# line extrapolated across the whole x range.
# A wider country set that was tried earlier:
# ['United Kingdom', 'United States', 'Italy', 'France', 'Germany']
for country_index, country in enumerate(['United Kingdom', 'United States', 'Germany']):
    # One colour per country, shared by its markers and its trend line.
    color = px.colors.qualitative.Dark2[country_index]

    country_df = df[df.location == country]
    country_df = country_df.dropna(subset=['total_vaccinations_per_hundred'])

    # scatter plot for data
    fig.add_trace(
        go.Scatter(
            x=country_df['date'],
            y=country_df['total_vaccinations_per_hundred'],
            mode="markers",
            marker_color=color,
            name=country,
            showlegend=False,  # the legend entry comes from the trend line
        )
    )

    # Only observations from 2021-02-25 onwards feed the fit. The mask is
    # built from country_df itself (a mask built from the full `df` has a
    # different index and has to be realigned by pandas), and the bound is a
    # bare date so that it compares correctly against the "YYYY-MM-DD"
    # strings in `date` and includes 25 February itself.
    recent_country_df = country_df[country_df.date >= "2021-02-25"]

    # A line needs at least two distinct x values; with fewer, add_constant
    # would not add an intercept column and the fit would be meaningless.
    if recent_country_df['day_of_year'].nunique() < 2:
        print(f"{country}: not enough recent data points for a trend line")
        continue

    x = sm.add_constant(recent_country_df['day_of_year'])
    model = sm.OLS(recent_country_df['total_vaccinations_per_hundred'], x)
    results = model.fit()
    print(results.params)

    # add_constant prepends the constant column, so the fitted parameters
    # are ordered [intercept, slope]. Use .iloc for positional access:
    # params is a label-indexed Series and integer keys on it are deprecated.
    b = results.params.iloc[0]
    m = results.params.iloc[1]

    # extrapolated line: from day 0 (value b) to day x_days (value m*x_days + b)
    fig.add_trace(
        go.Scatter(
            x=["2021-01-01 00:00:00.0000", x_max_date],
            y=[b, m * x_days + b],
            mode="lines",
            marker_color=color,
            name=country,
            showlegend=True,
        )
    )

fig.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment