Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save dantonnoriega/092d04deba671f65b7a99a244c957c9d to your computer and use it in GitHub Desktop.
A simple example of a difference-in-differences analysis in Python using simulated data
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import statsmodels.api as sm
# Set random seed for reproducibility
np.random.seed(42)

# Generate synthetic data: a control series and a treated series observed over
# the same n_obs periods, with an intervention that affects only the treated one.
n_obs = 100
intervention_time = n_obs // 2  # Intervention at the middle
intervention_effect = 2  # level shift added to treated outcomes from intervention_time on

time = np.arange(n_obs)
# Post-intervention indicator. Built from `time` so its length always equals
# n_obs (concatenating two n_obs // 2 halves is one element short for odd n_obs).
post = (time >= intervention_time).astype(float)
# Periods elapsed since the intervention; 0 up to and including intervention_time.
time_since = np.maximum(0, time - intervention_time)

control_trend = 1 + 0.1 * time + np.random.normal(0, .2, n_obs)
# The treated series also grows by an extra .13 per period after the intervention.
# NOTE: its pre-period slope (.13) differs from the control's (0.1), so the trends
# are only roughly parallel; the regression below therefore gives each group its
# own linear time trend.
treatment_trend = 3 + .13 * time + .13 * time_since + np.random.normal(0, .2, n_obs)

# Wide frame: one row per period, one column per group's series (used for plotting)
data = pd.DataFrame({
    'time': time,
    'post': post,
    'control_trend': control_trend,
    'treatment_trend': treatment_trend,
})

# Stack both groups into a long panel (one row per group-period). Keeping the
# groups separate, rather than summing them, is what makes this a
# difference-in-differences: the control series supplies the counterfactual
# change for the treated series.
panel = pd.DataFrame({
    'treated': np.repeat([0.0, 1.0], n_obs),
    'time': np.tile(time, 2),
    'post': np.tile(post, 2),
    'time_since': np.tile(time_since, 2),
    'outcome': np.concatenate((control_trend, treatment_trend)),
})
panel['treated_time'] = panel['treated'] * panel['time']
panel['treated_post'] = panel['treated'] * panel['post']
panel['treated_time_since'] = panel['treated'] * panel['time_since']
# Effect of intervention: applied only to treated observations in the post period
panel['outcome'] += intervention_effect * panel['treated_post']

# Run difference-in-differences regression with group-specific linear trends:
##   y = b0 + b1*treated + b2*time + b3*treated:time
##       + b4*post + b5*treated:post + b6*time_since + b7*treated:time_since
## b5 is the DiD estimate of the level jump at the intervention (simulated value:
## intervention_effect) and b7 the estimate of the post-intervention slope change
## (simulated value: .13), each net of whatever the control group does.
regressors = ['treated', 'time', 'treated_time', 'post', 'treated_post',
              'time_since', 'treated_time_since']
model = sm.OLS(panel['outcome'], sm.add_constant(panel[regressors]))
results = model.fit()
# Print regression results
print(results.summary())
# Plot the control and treated series, marking when the intervention happens
fig, ax = plt.subplots(figsize=(10, 6))
series_to_plot = (
    ('control_trend', 'Control Trend'),
    ('treatment_trend', 'Treatment Trend'),
)
for column, label in series_to_plot:
    ax.plot(data['time'], data[column], label=label)

# Dashed vertical line plus an arrow label at the intervention period
ax.axvline(x=intervention_time, color='gray', linestyle='--', label='Intervention Time')
ax.annotate(
    'Intervention',
    xy=(intervention_time, 3.5),
    xytext=(intervention_time + 5, 4.5),
    arrowprops=dict(arrowstyle='->'),
    fontsize=12,
)

ax.set(xlabel='Time', ylabel='Trends', title='Near-Parallel Trends Before Intervention')
ax.legend()
ax.grid(True)
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment