Created
December 7, 2022 22:19
-
-
Save arnos-stuff/e599e52aea6f1fef7306fd7c74ff1907 to your computer and use it in GitHub Desktop.
Inferring lognormal error distributions from odds ratios and Bayes' rule
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import List, Tuple | |
from numba import jit, float64 | |
import sympy as sp | |
import numpy as np | |
import scipy.optimize as opt | |
def from_confidence_intervals(lower: float, upper: float, symbolic=False) -> Tuple[float, float]:
    """Infer lognormal parameters from a two-sided 95% interval.

    The interval ``[lower, upper]`` is treated as the central 95% interval of
    a lognormally distributed quantity, i.e. on the log scale it is
    ``mu -/+ 1.96 * sigma``.

    :param lower: the lower bound of the interval (must be positive)
    :param upper: the upper bound of the interval (must be positive)
    :param symbolic: use sympy functions instead of numpy ones
    :return: ``(mu, sigma)``, the mean and standard deviation of the
        logarithm of the quantity
    """
    log = sp.log if symbolic else np.log
    log_lower, log_upper = log(lower), log(upper)
    # The midpoint of the log-interval is the log-scale mean.
    mean = (log_lower + log_upper) / 2.0
    # The log-interval is 2 * 1.96 standard deviations wide, so sigma is the
    # width itself (not its square root) divided by that factor.
    std = (log_upper - log_lower) / (2.0 * 1.96)
    return mean, std
def prod_lognormals_(
        mu1: float, sigma1: float,
        mu2: float, sigma2: float,
        symbolic=False) -> Tuple[float, float]:
    """Compute the parameters of the product of two lognormal distributions.

    With ``log(X1) ~ N(mu1, sigma1**2)`` and ``log(X2) ~ N(mu2, sigma2**2)``
    independent, ``log(X1 * X2)`` is normal with the means and the variances
    added.

    :param mu1: the log-scale mean of the first lognormal distribution
    :param sigma1: the log-scale standard deviation of the first lognormal distribution
    :param mu2: the log-scale mean of the second lognormal distribution
    :param sigma2: the log-scale standard deviation of the second lognormal distribution
    :param symbolic: use sympy functions instead of numpy ones
    :return: the log-scale mean and standard deviation of the product of the
        two lognormal distributions
    """
    # Honour the flag: numeric callers get numpy scalars, not sympy objects.
    sqrt = sp.sqrt if symbolic else np.sqrt
    mu = mu1 + mu2
    sigma = sqrt(sigma1**2 + sigma2**2)
    return mu, sigma
def sum_lognormals_(
        mu1: float, sigma1: float,
        mu2: float, sigma2: float,
        symbolic=False) -> Tuple[float, float]:
    """Approximate the sum of two independent lognormals by a single lognormal.

    Uses moment matching (the Fenton-Wilkinson approximation): the returned
    lognormal has the same mean and variance as the sum of the two inputs.

    :param mu1: the log-scale mean of the first lognormal distribution
    :param sigma1: the log-scale standard deviation of the first lognormal distribution
    :param mu2: the log-scale mean of the second lognormal distribution
    :param sigma2: the log-scale standard deviation of the second lognormal distribution
    :param symbolic: use sympy functions instead of numpy ones
    :return: the log-scale mean and standard deviation of the lognormal that
        approximates the sum of the two lognormal distributions

    NOTE(review): ``raw_guess_odds_ratio`` looks like it was generated from the
    symbolic form of this function; regenerate it after changing the formula.
    """
    if symbolic:
        exp, log, sqrt = sp.exp, sp.log, sp.sqrt
    else:
        exp, log, sqrt = np.exp, np.log, np.sqrt

    # Mean and variance of each lognormal term on the natural scale.
    mean1 = exp(mu1 + 0.5 * sigma1**2)
    mean2 = exp(mu2 + 0.5 * sigma2**2)
    var1 = exp(2 * mu1 + sigma1**2) * (exp(sigma1**2) - 1)
    var2 = exp(2 * mu2 + sigma2**2) * (exp(sigma2**2) - 1)

    # For independent terms the means and the variances both add.
    total_mean = mean1 + mean2
    total_var = var1 + var2

    # Invert the lognormal moment formulas:
    #   mean = exp(mu + sigma^2 / 2),  var = (exp(sigma^2) - 1) * mean^2
    sigma_sq = log(1 + total_var / total_mean**2)
    mu = log(total_mean) - 0.5 * sigma_sq
    return mu, sqrt(sigma_sq)
def sum_lognormals(*params: Tuple[float, float], symbolic: bool = False) -> Tuple[float, float]:
    """Compute the parameters of the sum of multiple lognormal distributions.

    The pairs are folded left to right with ``sum_lognormals_``.

    :param params: one ``(mu, sigma)`` tuple per lognormal distribution
    :param symbolic: forwarded to ``sum_lognormals_`` so that sympy inputs can
        be combined with sympy functions
    :return: the ``(mu, sigma)`` pair describing the sum of the lognormal distributions
    :raises IndexError: if no ``(mu, sigma)`` pair is given
    """
    if not params:
        raise IndexError('sum_lognormals() needs at least one (mu, sigma) pair')
    mu, sigma = params[0]
    for mu_i, sigma_i in params[1:]:
        mu, sigma = sum_lognormals_(mu, sigma, mu_i, sigma_i, symbolic=symbolic)
    return mu, sigma
def prod_lognormals(*params: Tuple[float, float], symbolic: bool = False) -> Tuple[float, float]:
    """Compute the parameters of the product of multiple lognormal distributions.

    The pairs are folded left to right with ``prod_lognormals_``.

    :param params: one ``(mu, sigma)`` tuple per lognormal distribution
    :param symbolic: forwarded to ``prod_lognormals_`` so that sympy inputs
        can be combined with sympy functions
    :return: the ``(mu, sigma)`` pair describing the product of the lognormal distributions
    :raises IndexError: if no ``(mu, sigma)`` pair is given
    """
    if not params:
        raise IndexError('prod_lognormals() needs at least one (mu, sigma) pair')
    mu, sigma = params[0]
    for mu_i, sigma_i in params[1:]:
        mu, sigma = prod_lognormals_(mu, sigma, mu_i, sigma_i, symbolic=symbolic)
    return mu, sigma
def symbolic_guess_odds_ratio(
        params_or=None,
) -> "Tuple[sp.Expr, sp.Expr]":
    """Build the symbolic residuals that tie q to the odds ratio.

    If both p & q are lognormally distributed, then the odds ratio is given by
        p/q ~ lognormal(mu_or, sigma_or)
    In the special case that p = prob(outcome|condition) and
    q = prob(outcome|~condition) we can infer the distribution of both p and q
    from the odds ratio. This assumes we know both the marginal distributions
    p(condition) and p(y), because
        p(y) = p(y|condition) * p(condition) + p(y|~condition) * p(~condition)
    which yields the following equation:
        lognormal(mu_outcome, sigma_outcome) =
            lognormal(mu_q, sigma_q) * lognormal(mu_or, sigma_or) * lognormal(mu_condition, sigma_condition) +
            lognormal(mu_q, sigma_q) - lognormal(mu_q, sigma_q) * lognormal(mu_condition, sigma_condition)
    which rearranges to
        sum_lognormals(outcome, product_lognormals(q, condition)) =
            sum_lognormals(q, product_lognormals(q, or, condition))

    :param params_or: optional ``(mu_or, sigma_or)`` pair for the odds ratio,
        numeric or symbolic; when omitted, fresh symbols ``mu_or`` and
        ``sigma_or`` are used instead of relying on a module-level global
    :return: a pair of simplified sympy expressions
        ``(mu_lhs - mu_rhs, sigma_lhs - sigma_rhs)``; both are zero when the
        two sides of the equation above have matching parameters
    """
    mu_q, mu_outcome, mu_condition = sp.symbols('mu_q, mu_outcome, mu_condition')
    sigma_q, sigma_outcome, sigma_condition = sp.symbols(
        'sigma_q, sigma_outcome, sigma_condition', positive=True)
    if params_or is None:
        params_or = (sp.Symbol('mu_or'), sp.Symbol('sigma_or', positive=True))

    params_q = (mu_q, sigma_q)
    params_outcome = (mu_outcome, sigma_outcome)
    params_condition = (mu_condition, sigma_condition)

    # The pairwise helpers are called directly with symbolic=True so that
    # sympy functions (not numpy ufuncs) are applied to the symbols.
    q_condition = prod_lognormals_(*params_q, *params_condition, symbolic=True)
    mu_lhs, sigma_lhs = sum_lognormals_(*params_outcome, *q_condition, symbolic=True)

    q_or = prod_lognormals_(*params_q, *params_or, symbolic=True)
    q_or_condition = prod_lognormals_(*q_or, *params_condition, symbolic=True)
    mu_rhs, sigma_rhs = sum_lognormals_(*params_q, *q_or_condition, symbolic=True)

    return sp.simplify(mu_lhs - mu_rhs), sp.simplify(sigma_lhs - sigma_rhs)
@jit((float64, float64, float64, float64, float64, float64), nopython=True)
def raw_guess_odds_ratio(
        mu_q: float, sigma_q: float,
        mu_outcome: float, sigma_outcome: float,
        mu_condition: float, sigma_condition: float,
) -> (float, float):
    """Evaluate the two numeric residuals for a candidate ``(mu_q, sigma_q)``.

    The function is compiled with numba for six ``float64`` arguments and
    returns ``(mu_expr, sigma_expr)``. ``solve_for_remaining_parameters``
    minimises the sum of their squares, i.e. it looks for the
    ``(mu_q, sigma_q)`` at which both values vanish.

    NOTE(review): the expressions look machine-generated (the ``__main__``
    block has commented-out ``sp.pycode`` calls), presumably from
    ``symbolic_guess_odds_ratio``. The literal 5.56635651637762 appears to be
    ``exp(mu_or + sigma_or**2 / 2)`` for the odds-ratio interval (4.1, 7.28)
    used in ``__main__``, and the other literals appear to derive from it --
    confirm, and regenerate if the upstream formulas or that interval change.

    :param mu_q: candidate log-scale mean for q
    :param sigma_q: candidate log-scale standard deviation for q
    :param mu_outcome: log-scale mean for the outcome
    :param sigma_outcome: log-scale standard deviation for the outcome
    :param mu_condition: log-scale mean for the condition
    :param sigma_condition: log-scale standard deviation for the condition
    :return: the pair ``(mu_expr, sigma_expr)``
    """
    # These tuples are not referenced anywhere below.
    params_q = (mu_q, sigma_q)
    params_outcome = (mu_outcome, sigma_outcome)
    params_condition = (mu_condition, sigma_condition)
    # First residual: a single fraction whose denominator is the product of
    # the two exponential sums that also appear in the numerator.
    mu_expr = (
        (
            np.exp(mu_outcome + 0.5*sigma_outcome**2) +
            np.exp(mu_condition + mu_q + 0.5*sigma_condition**2 + 0.5*sigma_q**2)
        )*(
            np.exp(mu_q + 0.5*sigma_q**2) + 5.56635651637762*np.exp(
                mu_condition + mu_q + 0.5*sigma_condition**2 + 0.5*sigma_q**2)
        )*(-np.log(
            (
                5.56635651637762*np.exp(mu_condition + 0.5*sigma_condition**2) + 1.0
            )*np.exp(mu_q + 0.5*sigma_q**2)) + np.log(
                np.exp(mu_outcome + 0.5*sigma_outcome**2) + np.exp(
                    mu_condition + mu_q + 0.5*sigma_condition**2 + 0.5*sigma_q**2
                )
            )
        ) + (
            np.exp(mu_outcome + 0.5*sigma_outcome**2) +
            np.exp(mu_condition + mu_q + 0.5*sigma_condition**2 + 0.5*sigma_q**2)
        )*(
            -0.5*(1 - np.exp(sigma_q**2))*np.exp(2*mu_q + sigma_q**2) +
            (
                16.0819550303169*np.exp(sigma_condition**2 + sigma_q**2) - 15.4921624337098
            )*np.exp(2*mu_condition + 2*mu_q + sigma_condition**2 + sigma_q**2) + 0.5
        ) + 0.5*(np.exp(mu_q + 0.5*sigma_q**2) + 5.56635651637762*np.exp(
            mu_condition + mu_q + 0.5*sigma_condition**2 + 0.5*sigma_q**2
            )
        )*(
            (1 - np.exp(sigma_outcome**2))*np.exp(2*mu_outcome + sigma_outcome**2) +
            (1 - np.exp(sigma_condition**2 + sigma_q**2))*np.exp(2*mu_condition + 2*mu_q + sigma_condition**2 + sigma_q**2) - 1)
    )/(
        (
            np.exp(mu_outcome + 0.5*sigma_outcome**2) +
            np.exp(mu_condition + mu_q + 0.5*sigma_condition**2 + 0.5*sigma_q**2)
        )*(
            np.exp(mu_q + 0.5*sigma_q**2) + 5.56635651637762*np.exp(mu_condition + mu_q + 0.5*sigma_condition**2 + 0.5*sigma_q**2)
        )
    )
    # Second residual: the difference of two square roots, each of a ratio.
    sigma_expr = np.sqrt(
        (
            -(1 - np.exp(sigma_outcome**2))*np.exp(2*mu_outcome + sigma_outcome**2) -
            (1 - np.exp(sigma_condition**2 + sigma_q**2))*np.exp(
                2*mu_condition + 2*mu_q + sigma_condition**2 + sigma_q**2
            ) + 1
        )/(
            np.exp(mu_outcome + 0.5*sigma_outcome**2) +
            np.exp(mu_condition + mu_q + 0.5*sigma_condition**2 + 0.5*sigma_q**2)
        )
    ) - np.sqrt(
        (
            -(1 - np.exp(sigma_q**2))*np.exp(2*mu_q + sigma_q**2) +
            30.9843248674196*(1.03807038553403*np.exp(sigma_condition**2 + sigma_q**2) - 1)*np.exp(
                2*mu_condition + 2*mu_q + sigma_condition**2 + sigma_q**2
            ) + 1
        )/(
            np.exp(mu_q + 0.5*sigma_q**2) +
            5.56635651637762*np.exp(mu_condition + mu_q + 0.5*sigma_condition**2 + 0.5*sigma_q**2)
        )
    )
    return mu_expr, sigma_expr
def solve_for_remaining_parameters(
        mu_outcome: float, sigma_outcome: float,
        mu_condition: float, sigma_condition: float,
        mu0_q: float, sigma0_q: float,
):
    """Solve numerically for the lognormal parameters of q.

    Minimises the sum of the squared residuals returned by
    ``raw_guess_odds_ratio`` over ``(mu_q, sigma_q)``, subject to
    ``sigma_q >= 0``.

    :param mu_outcome: log-scale mean for the outcome
    :param sigma_outcome: log-scale standard deviation for the outcome
    :param mu_condition: log-scale mean for the condition
    :param sigma_condition: log-scale standard deviation for the condition
    :param mu0_q: starting value for ``mu_q``
    :param sigma0_q: starting value for ``sigma_q``
    :return: the optimiser's solution as an array ``[mu_q, sigma_q]``
    """
    def objective(params_array: np.ndarray):
        delta_mu, delta_sigma = raw_guess_odds_ratio(
            params_array[0], params_array[1],
            mu_outcome, sigma_outcome,
            mu_condition, sigma_condition
        )
        return delta_mu**2 + delta_sigma**2

    # BFGS silently ignores constraints (scipy only emits a warning), so the
    # non-negativity of sigma was never enforced. 'trust-constr' accepts the
    # same 'disp' / 'maxiter' / 'gtol' options and honours the bounds.
    result = opt.minimize(
        fun=objective,
        x0=np.array([mu0_q, sigma0_q]),
        method='trust-constr',
        options={'disp': True, 'maxiter': 1000, 'gtol': 1e-3},
        # +inf > mu > -inf, +inf > sigma >= 0
        bounds=opt.Bounds(
            lb=np.array([-np.inf, 0.0]),
            ub=np.array([np.inf, np.inf]),
        ),
    )
    return result.x
if __name__ == '__main__':
    # Interval bounds for each quantity, converted below into (mu, sigma)
    # pairs by from_confidence_intervals.
    conf_interval_autism = (0.01, 0.02)
    conf_interval_gender = (0.004, 0.013)
    conf_interval_or = (4.1, 7.28)

    mu_autism, sigma_autism = params_autism = from_confidence_intervals(*conf_interval_autism)
    mu_gender, sigma_gender = params_gender = from_confidence_intervals(*conf_interval_gender)
    # Kept at module level: symbolic_guess_odds_ratio refers to `params_or`.
    params_or = from_confidence_intervals(*conf_interval_or)

    # To dump the symbolic expressions as Python code:
    # print(sp.pycode(mu_expr), end='\n\n')
    # print(sp.pycode(sigma_expr))

    # Starting point for the solver: the average of the two means and the
    # two standard deviations combined in quadrature.
    mu0 = (mu_autism + mu_gender) / 2.0
    s0 = np.sqrt(sigma_autism**2 + sigma_gender**2)

    # To inspect the residuals at the starting point:
    # fwd = raw_guess_odds_ratio(mu0, s0, mu_autism, sigma_autism, mu_gender, sigma_gender)
    # print(fwd)
    result = solve_for_remaining_parameters(
        mu_autism, sigma_autism,
        mu_gender, sigma_gender,
        mu0, s0,
    )
    print(result)
    # Previously observed output: mu_q = -4.35836312, sigma_q = 0.42120029
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment