Last active
December 8, 2025 02:51
-
-
Save tomicapretto/e6dd0a28d9ae328c4e4abe88b40ced4f to your computer and use it in GitHub Desktop.
R2D2 for GLMMs in Python
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import polars as pl | |
| from scipy import stats | |
| from scipy.optimize import minimize, minimize_scalar | |
| from scipy import special | |
| from abc import ABC, abstractmethod | |
| def gbp_pdf(x, a, b, c, d): | |
| log_num = ( | |
| np.log(c) | |
| + (a * c - 1) * (np.log(x)- np.log(d)) | |
| - (a + b) * np.log1p((x / d) ** c) | |
| ) | |
| log_den = np.log(d) + special.betaln(a, b) | |
| return np.exp(log_num - log_den) | |
class Family(ABC):
    """Base class for the distribution of W induced by a Beta(a, b) prior on R^2.

    Subclasses must define `self.a` and `self.b` (the Beta prior parameters)
    and implement `W_to_R2` and `pdf`.
    """

    @abstractmethod
    def W_to_R2(self, w):
        """Map values of W to the R^2 scale.

        Family specific; to be implemented by subclasses.
        """
        pass

    @abstractmethod
    def pdf(self, w):
        """Probability Density Function (PDF) of W.

        Family specific; to be implemented by subclasses.
        """
        pass

    def cdf(self, w):
        """Cumulative Distribution Function (CDF) of W.

        W is first mapped to the R^2 scale, where the CDF is simply the CDF
        of the Beta(a, b) prior on R^2.
        """
        return stats.beta.cdf(self.W_to_R2(w), a=self.a, b=self.b)

    def ppf(self, p, bounds=None):
        """Quantile function of W, obtained by numerically inverting the CDF.

        NOTE: quite sensitive to the choice of `bounds`.
        """
        if bounds is None:
            # Scan increasing upper bounds until one covers essentially all
            # the probability mass (the last candidate is used as a fallback).
            upper = 1
            for upper in (1, 10, 100, 1_000, 10_000, 100_000, 500_000, 1_000_000):
                if self.cdf(upper) > 0.99:
                    break
            bounds = (1 / 100_000, upper)
        log_lo, log_hi = np.log(bounds[0]), np.log(bounds[1])

        def invert_one(prob):
            # Squared distance between the CDF and the target probability,
            # minimized on the log scale so W stays positive.
            def objective(logw):
                return (self.cdf(w=np.exp(logw)) - prob) ** 2

            solution = minimize_scalar(objective, bounds=(log_lo, log_hi))
            return np.exp(solution.x)

        return np.array([invert_one(prob) for prob in np.atleast_1d(p)])
class PoissonFamily(Family):
    """Distribution of W for the Poisson family, under a Beta(a, b) prior on R^2."""

    def __init__(self, a, b, intercept):
        self.a = a
        self.b = b
        self.intercept = intercept

    def W_to_R2(self, w):
        # R^2 implied by a given W value for this family.
        growth = np.expm1(w)
        return growth / (growth + np.exp(-self.intercept - 0.5 * w))

    def pdf(self, w):
        # Closed-form density of W induced by the Beta(a, b) prior on R^2.
        a, b, alpha = self.a, self.b, self.intercept
        numerator = np.expm1(w) ** (a - 1) * np.exp(-b * (alpha + w / 2)) * (3 * np.exp(w) - 1)
        denominator = 2 * (np.expm1(w) + np.exp(-alpha - w / 2)) ** (a + b)
        return (numerator / denominator) / special.beta(a, b)
class NegativeBinomialFamily(Family):
    """Distribution of W for the negative binomial family, under a Beta(a, b) prior on R^2."""

    def __init__(self, a, b, intercept, theta):
        self.a = a
        self.b = b
        self.intercept = intercept
        self.theta = theta

    def W_to_R2(self, w):
        # R^2 implied by W; matches the Poisson expression when theta == 1.
        growth = np.expm1(w)
        return growth / (growth + self.theta * np.exp(-self.intercept - 0.5 * w))

    def pdf(self, w):
        # Closed-form density of W induced by the Beta(a, b) prior on R^2.
        a, b, alpha, theta = self.a, self.b, self.intercept, self.theta
        numerator = np.expm1(w) ** (a - 1) * np.exp(-b * (alpha + w / 2)) * (3 * np.exp(w) - 1)
        denominator = 2 * (np.expm1(w) + theta * np.exp(-alpha - w / 2)) ** (a + b)
        return (theta ** b / special.beta(a, b)) * (numerator / denominator)
class GaussianFamily(Family):
    """Distribution of W for the Gaussian family, under a Beta(a, b) prior on R^2."""

    def __init__(self, a, b, intercept, sigma):
        self.a = a
        self.b = b
        self.intercept = intercept
        self.sigma = sigma

    def W_to_R2(self, w):
        # Ratio of W to total variance, with residual variance sigma^2.
        residual_variance = self.sigma ** 2
        return w / (w + residual_variance)

    def pdf(self, w):
        # Density of W: a GBP(a, b, 1, sigma^2) density.
        return gbp_pdf(w, self.a, self.b, 1, self.sigma ** 2)
class LogisticFamily(Family):
    """Distribution of W for the logistic family, under a Beta(a, b) prior on R^2."""

    def __init__(self, a, b, intercept):
        self.a = a
        self.b = b
        self.intercept = intercept

    def _W_to_R2_scalar(self, w, K=1000):
        # Equation 13 in Yanchenko et al. (2024)
        # We are estimating the integrals via Quasi-Monte Carlo (QMC) integration:
        # the linear predictor eta ~ N(intercept, w) is "sampled" deterministically
        # at K - 1 evenly spaced quantiles instead of drawn at random.
        p_grid = np.linspace(1, K - 1, num=K - 1) / K
        eta = stats.norm(loc=self.intercept, scale=w ** 0.5).ppf(p_grid)
        mu = self._mean(eta)
        mu_1 = np.mean(mu).item()
        mu_2 = np.mean(mu ** 2).item()
        sigma_squared = np.mean(self._var(eta)).item()
        # Usage of M and V comes from Equation 3
        M = mu_2 - mu_1 ** 2  # Var(mu) = E(mu^2) - E(mu)^2
        V = sigma_squared  # mean of the conditional Bernoulli variances
        return M / (M + V)

    def W_to_R2(self, w):
        # Vectorized version of Equation 13 in Yanchenko et al. (2024)
        return np.array([self._W_to_R2_scalar(w) for w in np.atleast_1d(w)])

    def pdf(self, w):
        # Density function of W
        # Computes values of the density function of W induced by a Beta(a, b) prior on R2.
        # No closed form is available here, so the computation uses a numeric
        # approximation (central finite difference) to the derivative of the CDF.
        # NOTE: When `w[i] - delta` is < 0 for some index `i`, we'll get a warning and NaN.
        # The following mechanism tries to overcome that by picking the first
        # (largest) delta for which all evaluation points stay positive; if none
        # qualifies, the smallest candidate is used anyway.
        delta_candidates = [0.001, 0.0001, 0.00001]
        for delta in delta_candidates:
            if bool(np.all((w - delta) > 0)):
                break
        diff = self.cdf(w=w + delta) - self.cdf(w=w - delta)
        return diff / (2 * delta)

    def _mean(self, eta):
        # Inverse logit link: conditional mean given the linear predictor.
        return special.expit(eta)

    def _var(self, eta):
        # Bernoulli variance mu * (1 - mu).
        mean = self._mean(eta)
        return mean * (1 - mean)
def penalized_divergence(p_true, p_approx, params_current, params_reference, lam=0.25):
    """Pearson chi-squared divergence plus an L2 penalty on the parameters.

    The penalty term, weighted by `lam`, discourages parameter values that
    wander far away from the reference values.
    """
    chi_squared = np.sum(np.square(1 - p_approx / p_true))
    shrinkage = lam * np.sum(np.square(params_current - params_reference))
    return chi_squared + shrinkage
def WGBP(family, lam=0.25, x0=None, method="Powell"):
    """Compute parameters for the GBP Approximation.

    This function finds the closest Generalized Beta Prime (GBP) distribution
    to the true pdf of W as measured by the (penalized) Pearson Chi-squared
    divergence.

    Parameters
    ----------
    family : Family
        Instance providing `a`, `b`, `ppf`, and `pdf`.
    lam : float
        Weight of the L2 penalty towards the reference parameters.
    x0 : array-like or None
        Initial guess for the four GBP parameters on the natural scale.
        Defaults to all ones. (Built at call time — avoids the mutable
        array default argument anti-pattern.)
    method : str
        Optimization method passed to `scipy.optimize.minimize`.

    Returns
    -------
    np.ndarray
        The four optimized GBP parameters.

    Raises
    ------
    RuntimeError
        If the minimization does not converge.
    """
    if x0 is None:
        x0 = np.ones(4)
    a, b = family.a, family.b
    # Quantiles
    p = np.linspace(0.01, 0.99, num=500)
    # Values of 'w' where p_true and p_gbp are evaluated.
    w = family.ppf(p=p).flatten()
    p_true = family.pdf(w=w).flatten()
    # Copied from the R implementation because sometimes p_true has NaNs
    w = w[~np.isnan(p_true)]
    p_true = p_true[~np.isnan(p_true)]
    params_reference = np.array([a, b, 1, 1])

    def divergence(log_params):
        # Optimize on the log scale so all four GBP parameters stay positive.
        params = np.exp(log_params)
        return penalized_divergence(
            p_true=p_true,
            p_approx=gbp_pdf(w, *params),
            params_current=params,
            params_reference=params_reference,
            lam=lam
        )

    result = minimize(divergence, x0=np.log(x0), method=method)
    if result.success:
        return np.exp(result.x)
    # Surface the optimizer's own diagnostic instead of a bare Exception.
    raise RuntimeError(f"Minimization didn't converge: {result.message}")
def plot_w_approximations(rows_dict, family, **family_kwargs):
    """Plot the true density of W against the paper's and our own GBP approximations.

    For each row (one prior/intercept configuration), fits our own GBP
    approximation via `WGBP`, evaluates it alongside the paper's reported
    parameters and the true density, and saves a comparison figure to
    `imgs/<family>/a-<a>_b-<b>_intercept-<intercept>.png`.

    Parameters
    ----------
    rows_dict : list of dict
        Rows with keys "a", "b", "beta_0", "family", and the paper's
        "a_star", "b_star", "c_star", "d_star" parameters.
    family : type
        A `Family` subclass to instantiate for each row.
    **family_kwargs
        Extra keyword arguments forwarded to the family constructor
        (e.g. `theta` for the negative binomial family).
    """
    def format_params(params):
        # Render the four GBP parameters as a LaTeX string for the plot title.
        names = ["a", "b", "c", "d"]
        return "$" + ", ".join(f"{n}^*={p:.2f}" for n, p in zip(names, params)) + "$"
    # NOTE: The divergence is _very_ sensitive to the range of values for 'w'.
    # When the true density is near zero, it goes up a lot.
    for row in rows_dict:
        # Get parameters
        a, b, intercept = row["a"], row["b"], row["beta_0"]
        params_paper = [row[name] for name in ["a_star", "b_star", "c_star", "d_star"]]
        family_name = "_".join(row["family"].split())
        # Compute approximations
        family_obj = family(a=a, b=b, intercept=intercept, **family_kwargs)
        w_lower = max(family_obj.ppf(0.01).item(), 0.001)  # Patch for small lower bounds
        w_upper = min(family_obj.ppf(0.99).item(), 250)  # Patch for large upper bounds
        w = np.linspace(w_lower, w_upper, num=500)
        params_own = WGBP(family_obj)
        pdf_own = gbp_pdf(w, *params_own)
        pdf_paper = gbp_pdf(w, *params_paper)
        w_pdf = family_obj.pdf(w=w)
        # Divergences of both approximations against the true density,
        # using the same penalized criterion that WGBP minimizes.
        divergence_own = penalized_divergence(
            p_true=w_pdf,
            p_approx=pdf_own,
            params_current=params_own,
            params_reference=np.array([a, b, 1, 1])
        )
        divergence_paper = penalized_divergence(
            p_true=w_pdf,
            p_approx=pdf_paper,
            params_current=params_paper,
            params_reference=np.array([a, b, 1, 1])
        )
        # Plot
        fig, ax = plt.subplots(figsize=(6, 5), tight_layout=True)
        ax.plot(w, pdf_paper, label="Paper")
        ax.plot(w, pdf_own, label="Own")
        ax.plot(w, w_pdf, color="0.3", ls="--", label="True")
        title_left = "\n".join(
            [
                "$\\bf{Approximations}$",
                format_params(params_paper) + " [paper]",
                format_params(params_own) + " [own]",
                "$\\bf{Divergences}$",
                f"{divergence_paper:.2f} [paper] vs. " + f"{divergence_own:.2f} [own]",
            ]
        )
        title_right = "\n".join(
            [
                f"$R^2 \\sim$ Beta({a}, {b})",
                f"$\\alpha = {intercept}$"
            ]
        )
        ax.text(x=0, y=1.025, s=title_left, ha="left", size=11, transform=ax.transAxes)
        ax.text(x=1, y=1.025, s=title_right, ha="right", size=13, transform=ax.transAxes)
        ax.set(xlabel="W", yticks=[])
        ax.legend(loc="upper right")
        # NOTE(review): assumes the imgs/<family_name>/ directory already exists.
        file_name = f"imgs/{family_name}/a-{a}_b-{b}_intercept-{intercept}.png"
        fig.savefig(file_name)
        plt.close(fig)
    return None
if __name__ == "__main__":
    # Paper's reference GBP parameters, one row per (family, prior, intercept).
    df = pl.read_csv("data/paper_approximations.csv")
    # Family classes and their extra constructor arguments.
    families = {
        "poisson": {"class": PoissonFamily, "kwargs": {}},
        "nb": {"class": NegativeBinomialFamily, "kwargs": {"theta": 2}},
        "logistic": {"class": LogisticFamily, "kwargs": {}},
    }
    # One batch of comparison plots per family/intercept combination.
    for family_name, family_dict in families.items():
        for intercept in [-2, 0, 2]:
            print(f"Family: {family_name}, Intercept: {intercept}")
            plot_w_approximations(
                df.filter(
                    pl.col("family") == family_name,
                    pl.col("beta_0") == intercept
                ).to_dicts(),
                family_dict["class"],
                **family_dict["kwargs"]
            );
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| beta_0 | a | b | a_star | b_star | c_star | d_star | family | |
|---|---|---|---|---|---|---|---|---|
| -2 | 0.5 | 0.5 | 0.19 | 0.77 | 4.22 | 3.17 | poisson | |
| -2 | 0.5 | 0.5 | 0.48 | 0.22 | 1.23 | 1.78 | logistic | |
| -2 | 0.5 | 0.5 | 0.21 | 0.74 | 4.78 | 3.49 | nb | |
| -2 | 1 | 1 | 0.42 | 1.50 | 3.75 | 2.56 | poisson | |
| -2 | 1 | 1 | 1.45 | 0.51 | 0.99 | 1.74 | logistic | |
| -2 | 1 | 1 | 0.44 | 1.46 | 4.31 | 2.93 | nb | |
| -2 | 1 | 4 | 0.36 | 4.29 | 3.32 | 1.98 | poisson | |
| -2 | 1 | 4 | 0.99 | 1.72 | 1.19 | 2.53 | logistic | |
| -2 | 1 | 4 | 0.36 | 4.98 | 3.95 | 2.51 | nb | |
| -2 | 4 | 1 | 2.81 | 2.61 | 2.84 | 2.43 | poisson | |
| -2 | 4 | 1 | 8.21 | 0.65 | 0.74 | 1.49 | logistic | |
| -2 | 4 | 1 | 2.50 | 2.04 | 3.65 | 2.76 | nb | |
| -2 | 4 | 4 | 2.00 | 6.38 | 3.14 | 2.25 | poisson | |
| -2 | 4 | 4 | 8.15 | 2.18 | 0.88 | 1.57 | logistic | |
| -2 | 4 | 4 | 3.50 | 6.99 | 2.95 | 2.45 | nb | |
| 0 | 0.5 | 0.5 | 0.23 | 0.96 | 2.31 | 2.03 | poisson | |
| 0 | 0.5 | 0.5 | 0.72 | 0.39 | 0.85 | 1.31 | logistic | |
| 0 | 0.5 | 0.5 | 0.20 | 0.87 | 2.98 | 2.47 | nb | |
| 0 | 1 | 1 | 0.50 | 1.83 | 2.00 | 1.45 | poisson | |
| 0 | 1 | 1 | 1.47 | 0.67 | 0.77 | 1.68 | logistic | |
| 0 | 1 | 1 | 0.44 | 1.67 | 2.60 | 1.84 | nb | |
| 0 | 1 | 4 | 0.63 | 5.49 | 1.52 | 0.95 | poisson | |
| 0 | 1 | 4 | 1.17 | 2.12 | 0.89 | 2.03 | logistic | |
| 0 | 1 | 4 | 0.50 | 4.85 | 2.00 | 1.27 | nb | |
| 0 | 4 | 1 | 2.08 | 2.68 | 1.92 | 1.53 | poisson | |
| 0 | 4 | 1 | 7.72 | 0.72 | 0.68 | 1.44 | logistic | |
| 0 | 4 | 1 | 2.24 | 2.65 | 2.24 | 1.85 | nb | |
| 0 | 4 | 4 | 2.10 | 6.65 | 1.83 | 1.12 | poisson | |
| 0 | 4 | 4 | 7.37 | 2.79 | 0.72 | 1.65 | logistic | |
| 0 | 4 | 4 | 1.83 | 6.65 | 2.28 | 1.57 | nb | |
| 2 | 0.5 | 0.5 | 0.49 | 1.38 | 0.93 | 0.70 | poisson | |
| 2 | 0.5 | 0.5 | 0.48 | 0.22 | 1.23 | 1.78 | logistic | |
| 2 | 0.5 | 0.5 | 0.37 | 1.19 | 1.26 | 1.11 | nb | |
| 2 | 1 | 1 | 0.99 | 2.33 | 0.94 | 0.38 | poisson | |
| 2 | 1 | 1 | 1.45 | 0.51 | 0.99 | 1.74 | logistic | |
| 2 | 1 | 1 | 0.80 | 2.27 | 1.16 | 0.71 | nb | |
| 2 | 1 | 4 | 1.14 | 1.98 | 0.89 | 0.44 | poisson | |
| 2 | 1 | 4 | 0.99 | 1.72 | 1.19 | 2.53 | logistic | |
| 2 | 1 | 4 | 0.96 | 8.19 | 1.03 | 0.55 | nb | |
| 2 | 4 | 1 | 2.38 | 2.77 | 1.11 | 0.53 | poisson | |
| 2 | 4 | 1 | 8.21 | 0.65 | 0.74 | 1.49 | logistic | |
| 2 | 4 | 1 | 2.16 | 2.78 | 1.35 | 0.87 | nb | |
| 2 | 4 | 4 | 3.66 | 9.86 | 0.97 | 0.36 | poisson | |
| 2 | 4 | 4 | 8.15 | 2.18 | 0.88 | 1.57 | logistic | |
| 2 | 4 | 4 | 2.92 | 6.44 | 1.24 | 0.43 | nb |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment