Skip to content

Instantly share code, notes, and snippets.

@apoorvalal
Last active May 25, 2024 16:05
Show Gist options
  • Save apoorvalal/e7dc9f3e52dcd9d51854b28b3e8a7ba4 to your computer and use it in GitHub Desktop.
Save apoorvalal/e7dc9f3e52dcd9d51854b28b3e8a7ba4 to your computer and use it in GitHub Desktop.
from joblib import Parallel, delayed
import numpy as np
import pandas as pd
class LinearMediation:
    """Linear mediation decomposition with nonparametric bootstrap intervals.

    With treatments ``X`` (N x K), mediators ``W`` (N x L) and outcome ``y``:

    * total effect    ``beta``          -- OLS coefficients of ``y`` on ``X``
    * mediated effect ``delta @ gamma`` -- ``delta`` from OLS of ``W`` on ``X``,
      ``gamma`` from OLS of ``y`` on ``W``
    * direct effect   ``beta - delta @ gamma``

    Typical use: ``fit`` -> ``bootstrap`` -> ``summary``.
    """

    def __init__(self):
        pass

    def fit(self, X, W, y, store=True):
        """Fit Linear Mediation Model

        Args:
            X (2D Array): Treatment variable matrix (N x K)
            W (2D Array): Mediator variable matrix (N x L)
            y (1D Array): Outcome variable array (N x 1)
            store (bool, optional): Store estimates (and the data that
                ``bootstrap`` resamples) on the instance? Defaults to True.
                The bootstrap reuses this method with False.

        Returns:
            ``self`` when ``store`` is True; otherwise the tuple
            ``(total_effect, mediated_effect, direct_effect)``.
        """
        X, W, y = np.asarray(X), np.asarray(W), np.asarray(y)
        # rcond=None selects numpy's documented machine-precision rank cutoff.
        # (numpy documents rcond=1 as treating every singular value smaller
        # than the largest one as zero, which is not what is wanted here.)
        beta_tilde = np.linalg.lstsq(X, y, rcond=None)[0]
        delta_tilde = np.linalg.lstsq(X, W, rcond=None)[0]
        gamma_tilde = np.linalg.lstsq(W, y, rcond=None)[0]
        total_effect = beta_tilde
        mediated_effect = delta_tilde @ gamma_tilde
        direct_effect = total_effect - mediated_effect
        if not store:
            return total_effect, mediated_effect, direct_effect

        # bootstrap()/_bootstrap() resample these rows; summary() slices by K.
        self.X, self.W, self.y = X, W, y
        self.N, self.K = X.shape[0], X.shape[1]
        self.beta_tilde = beta_tilde
        self.delta_tilde = delta_tilde
        self.gamma_tilde = gamma_tilde
        self.total_effect = total_effect
        self.mediated_effect = mediated_effect
        self.direct_effect = direct_effect
        return self

    def bootstrap(self, B=1_000, alpha=0.05, seed=None, n_jobs=-1):
        """
        Bootstrap Confidence Intervals for Total, Mediated and Direct Effects

        Args:
            B (int, optional): Number of bootstrap replicates. Defaults to 1000.
            alpha (float, optional): Intervals have level ``1 - alpha``
                (percentile method). Defaults to 0.05.
            seed (int or None, optional): Seed for reproducible resampling.
                Defaults to None (fresh OS entropy on every call).
            n_jobs (int, optional): Passed to ``joblib.Parallel``.
                Defaults to -1 (all cores).

        Raises:
            RuntimeError: if ``fit`` has not been called with ``store=True``.
            ValueError: if ``B < 1`` or ``alpha`` is not in (0, 1).
        """
        if not hasattr(self, "X"):
            raise RuntimeError("Call fit(X, W, y) before bootstrap().")
        if B < 1:
            raise ValueError(f"B must be a positive integer, got {B}")
        if not 0 < alpha < 1:
            raise ValueError(f"alpha must lie in (0, 1), got {alpha}")
        self.alpha = alpha
        self.B = B
        # One independent child seed per replicate: reproducible for a fixed
        # seed, and worker processes never share a random stream.
        child_seeds = np.random.SeedSequence(seed).spawn(B)
        draws = Parallel(n_jobs=n_jobs)(
            delayed(self._bootstrap)(child) for child in child_seeds
        )
        # Shape (B, 3K): each row is [total | mediated | direct] for one replicate.
        self._bootstrapped = np.vstack(draws)
        self.ci = np.percentile(
            self._bootstrapped, 100 * np.array([alpha / 2, 1 - alpha / 2]), axis=0
        )

    def summary(self):
        """
        Summary Table for Total, Mediated and Direct Effects

        Also stores ``total_effects_summary``, ``mediated_effects_summary`` and
        ``direct_effects_summary``. Each is a (K x 3) array of columns
        [estimate, CI lower, CI upper], one row per column of ``X``.

        Returns:
            pandas.DataFrame: estimate and interval bounds for column 1 of ``X``.

        Raises:
            RuntimeError: if ``bootstrap`` has not been run.
        """
        if not hasattr(self, "ci"):
            raise RuntimeError("Call bootstrap() before summary().")
        K = self.K
        # self.ci has shape (2, 3K): rows are lower/upper percentiles, and the
        # columns hold the total, mediated and direct blocks of K each.
        self.total_effects_summary = np.c_[self.total_effect, self.ci[:, :K].T]
        self.mediated_effects_summary = np.c_[
            self.mediated_effect, self.ci[:, K : 2 * K].T
        ]
        self.direct_effects_summary = np.c_[self.direct_effect, self.ci[:, -K:].T]
        # The summary table assumes column 0 of X is an intercept and reports
        # only column 1 (a single treatment). With several treatments, use the
        # *_effects_summary arrays directly.
        self.summary_table = pd.DataFrame(
            {
                "Total Effect": self.total_effects_summary[1, :],
                "Mediated Effect": self.mediated_effects_summary[1, :],
                "Direct Effect": self.direct_effects_summary[1, :],
            },
            index=[
                "Estimate",
                f"CI Lower ({self.alpha/2})",
                f"CI Upper ({1-self.alpha/2})",
            ],
        )
        return self.summary_table

    def _bootstrap(self, rng=None):
        """
        One bootstrap replication: resample rows with replacement and refit.

        Args:
            rng: Anything accepted by ``np.random.default_rng`` (None, an int,
                a ``SeedSequence`` or a ``Generator``).

        Returns:
            1D array of length 3K: total, mediated and direct effects, concatenated.
        """
        generator = np.random.default_rng(rng)
        idx = generator.integers(0, self.N, size=self.N)
        estimates = self.fit(self.X[idx], self.W[idx], self.y[idx], store=False)
        # Flatten so that stacking replicates yields a 2D (B, 3K) array.
        return np.concatenate([np.ravel(e) for e in estimates])
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Linear Mediation"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" <div id=\"NQ0FBq\"></div>\n",
" <script type=\"text/javascript\" data-lets-plot-script=\"library\">\n",
" if(!window.letsPlotCallQueue) {\n",
" window.letsPlotCallQueue = [];\n",
" }; \n",
" window.letsPlotCall = function(f) {\n",
" window.letsPlotCallQueue.push(f);\n",
" };\n",
" (function() {\n",
" var script = document.createElement(\"script\");\n",
" script.type = \"text/javascript\";\n",
" script.src = \"https://cdn.jsdelivr.net/gh/JetBrains/lets-plot@v4.2.0/js-package/distr/lets-plot.min.js\";\n",
" script.onload = function() {\n",
" window.letsPlotCall = function(f) {f();};\n",
" window.letsPlotCallQueue.forEach(function(f) {f();});\n",
" window.letsPlotCallQueue = [];\n",
" \n",
" };\n",
" script.onerror = function(event) {\n",
" window.letsPlotCall = function(f) {}; // noop\n",
" window.letsPlotCallQueue = [];\n",
" var div = document.createElement(\"div\");\n",
" div.style.color = 'darkred';\n",
" div.textContent = 'Error loading Lets-Plot JS';\n",
" document.getElementById(\"NQ0FBq\").appendChild(div);\n",
" };\n",
" var e = document.getElementById(\"NQ0FBq\");\n",
" e.appendChild(script);\n",
" })()\n",
" </script>\n",
" "
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import pyfixest as pf"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from linear_mediation import LinearMediation"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th>deny</th>\n",
" <th>0</th>\n",
" <th>1</th>\n",
" </tr>\n",
" <tr>\n",
" <th>black</th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0.886854</td>\n",
" <td>0.113146</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0.692308</td>\n",
" <td>0.307692</td>\n",
" </tr>\n",
" <tr>\n",
" <th>All</th>\n",
" <td>0.855726</td>\n",
" <td>0.144274</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
"deny 0 1\n",
"black \n",
"0 0.886854 0.113146\n",
"1 0.692308 0.307692\n",
"All 0.855726 0.144274"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"hmda = pd.read_csv(\"data/hmda_aer.csv\")\n",
"hmda[\"black\"] = np.where(hmda.s13 == 3, 1, 0)\n",
"hmda[\"deny\"] = np.where(hmda.s7 == 3, 1, 0)\n",
"pd.crosstab(hmda.black, hmda.deny, margins=True, normalize=\"index\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.19454619454619457"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pf.feols(\"deny ~ black\", data=hmda, vcov=\"HC1\").coef().iloc[1]"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"df_small = (\n",
" hmda[[\"deny\", \"black\", \"s31a\", \"s42\", \"s43\", \"s44\", \"s25a\"]]\n",
" .query(\"s25a < 1000\")\n",
" .dropna()\n",
")\n",
"y = df_small.deny.values\n",
"X = np.c_[np.ones(df_small.shape[0]), df_small.black.values]\n",
"W = np.c_[np.ones(df_small.shape[0]), df_small.iloc[:, 2:].values]"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Total Effect</th>\n",
" <th>Mediated Effect</th>\n",
" <th>Direct Effect</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>Estimate</th>\n",
" <td>0.195167</td>\n",
" <td>0.079330</td>\n",
" <td>0.115837</td>\n",
" </tr>\n",
" <tr>\n",
" <th>CI Lower (0.025)</th>\n",
" <td>0.150601</td>\n",
" <td>0.062152</td>\n",
" <td>0.075980</td>\n",
" </tr>\n",
" <tr>\n",
" <th>CI Upper (0.975)</th>\n",
" <td>0.241546</td>\n",
" <td>0.099950</td>\n",
" <td>0.154805</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Total Effect Mediated Effect Direct Effect\n",
"Estimate 0.195167 0.079330 0.115837\n",
"CI Lower (0.025) 0.150601 0.062152 0.075980\n",
"CI Upper (0.975) 0.241546 0.099950 0.154805"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"m = LinearMediation()\n",
"m.fit(X, W, y)\n",
"m.bootstrap()\n",
"m.summary()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment