Skip to content

Instantly share code, notes, and snippets.

@airalcorn2
Last active June 20, 2021 12:21
Show Gist options
  • Save airalcorn2/2ff737106ee02bf103d7cac703558f06 to your computer and use it in GitHub Desktop.
Save airalcorn2/2ff737106ee02bf103d7cac703558f06 to your computer and use it in GitHub Desktop.
Testing out the synthetic control approach used in Dave et al. (2020).
# See: https://matheusfacure.github.io/python-causality-handbook/15-Synthetic-Control.html
# and: http://ftp.iza.org/dp13670.pdf.
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from datetime import datetime, timedelta
from scipy.optimize import fmin_slsqp
from toolz import partial
# Resident population of each of the 50 U.S. states, keyed by two-letter
# postal code; used to scale the covidtracking `positive` column to a
# per-1,000 rate. DC and the territories have no entry.
# NOTE(review): figures look like Census Bureau July 2019 estimates --
# confirm the source before updating them.
state_pops = dict(
    AL=4_903_185,
    AK=731_545,
    AZ=7_278_717,
    AR=3_017_825,
    CA=39_512_223,
    CO=5_758_736,
    CT=3_565_287,
    DE=973_764,
    FL=21_477_737,
    GA=10_617_423,
    HI=1_415_872,
    ID=1_787_065,
    IL=12_671_821,
    IN=6_732_219,
    IA=3_155_070,
    KS=2_913_314,
    KY=4_467_673,
    LA=4_648_794,
    ME=1_344_212,
    MD=6_045_680,
    MA=6_949_503,
    MI=9_986_857,
    MN=5_639_632,
    MS=2_976_149,
    MO=6_137_428,
    MT=1_068_778,
    NE=1_934_408,
    NV=3_080_156,
    NH=1_359_711,
    NJ=8_882_190,
    NM=2_096_829,
    NY=19_453_561,
    NC=10_488_084,
    ND=762_062,
    OH=11_689_100,
    OK=3_956_971,
    OR=4_217_737,
    PA=12_801_989,
    RI=1_059_361,
    SC=5_148_714,
    SD=884_659,
    TN=6_833_174,
    TX=28_995_881,
    UT=3_205_958,
    VT=623_989,
    VA=8_535_519,
    WA=7_614_893,
    WV=1_792_147,
    WI=5_822_434,
    WY=578_759,
)
# See footnote 17 in the paper.
def loss_w(W, X, y):
    """Return the mean absolute error between ``y`` and the weighted donors.

    Computes ``mean(|y - X @ W|)``, the objective minimised when fitting
    synthetic-control weights. In this script ``X`` has one row per day and
    one column per donor state, ``W`` has one weight per donor, and ``y``
    has one value per day.

    Args:
        W: Donor weights, length n_donors.
        X: Donor outcomes, shape (n_periods, n_donors).
        y: Target outcomes, length n_periods.

    Returns:
        The mean absolute residual as a NumPy float. Array-likes such as
        nested lists are accepted and converted with ``np.asarray``.
    """
    # Convert up front so plain Python sequences work with `@`.
    W = np.asarray(W)
    X = np.asarray(X)
    y = np.asarray(y)
    return np.mean(np.abs(y - X @ W))
# Download and prepare data.
# United States data.
# NOTE(review): confirm this endpoint still serves data before relying on it.
states_url = "https://covidtracking.com/api/states/daily.csv"
df_states = pd.read_csv(states_url).sort_values("date")
# `date` arrives as an integer such as 20200803; parse the whole column in
# one vectorized call instead of a per-row strptime.
df_states["date"] = pd.to_datetime(df_states["date"].astype(str), format="%Y%m%d")
# Any `state` code missing from state_pops maps to NaN (not a -1 sentinel),
# so its per-capita rate is NaN instead of a meaningless negative number.
df_states["population"] = df_states["state"].map(state_pops)
df_states["cases_per_1000"] = 1000 * df_states["positive"] / df_states["population"]
# Countries data.
countries_url = "https://covid.ourworldindata.org/data/owid-covid-data.csv"
countries_data = pd.read_csv(countries_url)
# OWID provides cases per million; divide by 1000 to get cases per 1,000.
countries_data["cases_per_1000"] = countries_data["total_cases_per_million"] / 1000
# Settings.
# Study window from the paper: a pre-period of `prev_days` days ending the
# day before the rally, followed by the post-period up to `stop_date`.
prev_days = 28
sturgis_rally = datetime(2020, 8, 3)
stop_date = datetime(2020, 9, 2)
start_date = sturgis_rally - timedelta(days=prev_days)
# Treated state, and states kept out of its donor pool.
target_state = "SD"
exclude_states = {"IA", "MN", "MT", "NE", "ND", "WY"}
# Column compared across all units.
outcome_var = "cases_per_1000"
# Three-letter ISO code of a country fitted as an additional target.
target_country = "FRA"
# Build donor pool.
# Donors: every state in state_pops except the excluded ones and the target,
# in alphabetical order so that column j of X belongs to donor_states[j].
donor_states = sorted(set(state_pops) - exclude_states - {target_state})
# X is expected to hold one row per calendar day in [start_date, stop_date].
# The pre-period slice X[:prev_days] further down relies on there being no
# gaps, so a donor with missing (or extra) days is rejected here.
n_days = (stop_date - start_date).days + 1
in_window = (start_date <= df_states["date"]) & (df_states["date"] <= stop_date)
# df_states is sorted by date, so each series is in chronological order.
donor_series = {
    state: df_states.loc[
        in_window & (df_states["state"] == state), outcome_var
    ].to_numpy()
    for state in donor_states
}
incomplete = [state for state, series in donor_series.items() if len(series) != n_days]
if incomplete:
    raise ValueError(
        f"Donor states without exactly {n_days} daily rows in the window: "
        f"{incomplete}"
    )
# Shape (n_days, n_donors): rows are days, columns are donor states.
X = np.stack([donor_series[state] for state in donor_states]).T
# State target.
df_state = df_states[df_states["state"] == target_state]
df_state = df_state[(start_date <= df_state["date"]) & (df_state["date"] <= stop_date)]
# Country target.
# .copy() makes country_data an independent frame, so the date assignment
# below modifies it directly rather than a slice of countries_data (which
# pandas flags with SettingWithCopyWarning).
country_data = countries_data[countries_data["iso_code"] == target_country].copy()
country_data["date"] = pd.to_datetime(
    country_data["date"].astype(str), format="%Y-%m-%d"
)
# Sort so the country series is chronological, like the state series.
country_data = country_data[
    (start_date <= country_data["date"]) & (country_data["date"] <= stop_date)
].sort_values("date")
# Observed outcome series to fit, keyed by locale code.
targets = {
    target_state: df_state[outcome_var].values,
    target_country: country_data[outcome_var].values,
}
# For each target: fit synthetic-control weights on the pre-period, print
# them, and plot the observed series against its synthetic counterpart.
n_donors = X.shape[1]
# Every series is drawn against the target state's dates, which must line up
# one-to-one with the rows of X (checked per target below).
dates = df_state["date"]
for locale, y in targets.items():
    print(f"\n{locale}\n")
    if len(y) != X.shape[0]:
        raise ValueError(
            f"{locale} has {len(y)} observations but the donor matrix has "
            f"{X.shape[0]} rows; the series are not aligned by date."
        )
    # Minimise the pre-period loss over convex combinations of donors:
    # each weight lies in [0, 1] and the weights sum to 1.
    weights, _, _, exit_mode, exit_msg = fmin_slsqp(
        func=partial(loss_w, X=X[:prev_days], y=y[:prev_days]),
        x0=np.full(n_donors, 1 / n_donors),
        f_eqcons=lambda w: np.sum(w) - 1,
        bounds=[(0.0, 1.0)] * n_donors,
        disp=False,
        full_output=True,
    )
    # disp=False hides SLSQP's own diagnostics, so surface failures here.
    if exit_mode != 0:
        print(f"Warning: SLSQP did not converge for {locale}: {exit_msg}")
    for idx in np.argsort(-weights):
        print(f"{donor_states[idx]}: {weights[idx]:.4}")
    # For the state target this corresponds to the paper's Figure 5 Panel (c)
    # (South Dakota); the same plot is drawn for every other target.
    # A fresh figure keeps targets from drawing onto the same axes.
    plt.figure()
    sns.lineplot(x=dates, y=y, color="red", label=locale)
    fake_y = X @ weights
    sns.lineplot(
        x=dates, y=fake_y, linestyle="--", color="blue", label=f"synthetic {locale}"
    )
    plt.axvline(sturgis_rally, linestyle="--", color="red")
    plt.title(locale)
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment