Last active
June 29, 2017 10:40
-
-
Save yoavram/c6a4faa9abdaeb2daeba to your computer and use it in GitHub Desktop.
Truncated normal distribution in Python: a translation of the C code of CRAN/truncnorm (https://github.com/cran/truncnorm) to Python. The examples in `__main__` compare the results against R.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
from scipy.stats import norm

# Relative-accuracy factor used by truncnorm_zeroin to form its actual
# tolerance: tol_act = 2 * EPSILON * |b| + tol / 2.
EPSILON = 1e-10
# Largest finite double. Used as a sentinel start value by the rejection
# samplers; the former literal 1.79769e+308 was a rounded, smaller value.
DBL_MAX = float(np.finfo(float).max)
# 1 / sqrt(2 * pi), i.e. the standard normal density at 0.
M_1_SQRT_2PI = 1.0/np.sqrt(2.0 * np.pi)
# Sampler-selection thresholds used by r_truncnorm:
t1 = 0.15   # case 2: density at a bound <= t1 -> normal rejection, else uniform
t2 = 2.18   # cases 3/3s: density ratio <= t2 -> uniform rejection
t3 = 0.725  # cases 3/3s: half-normal vs exponential rejection boundary
t4 = 0.45   # NOTE(review): not used by r_truncnorm; presumably the one-sided
            # sampler threshold - confirm.
# R-style random number helpers (named after R's runif / rnorm / rexp).
def runif(a, b):
    """Return one uniform draw between a and b, computed as a + (b - a) * U, U in [0, 1)."""
    width = b - a
    return a + width * np.random.random_sample()


# Aliases so translated C code can call rnorm(mean, sd) and rexp(scale),
# where rexp takes the scale 1/lambda.
rnorm, rexp = np.random.normal, np.random.exponential
def truncnorm_zeroin(ax, bx, fa, fb, f, info, Tol, Maxit, eps=None):
    """Find a root of ``f`` between ``ax`` and ``bx`` (Brent-style zeroin).

    Combines bisection with linear (secant) and inverse quadratic
    interpolation, keeping the root bracketed between ``b`` and ``c``.

    Parameters
    ----------
    ax, bx : float
        Left and right borders of the range in which the root is sought.
    fa, fb : float
        ``f(ax, info)`` and ``f(bx, info)``, supplied by the caller.
    f : callable
        Function under investigation, called as ``f(x, info)``.
    info : object
        Extra data passed unchanged to ``f``.
    Tol : float
        Acceptable absolute tolerance.
    Maxit : int
        Maximum number of iterations; the loop body runs at most
        ``Maxit + 1`` times.
    eps : float, optional
        Relative tolerance factor; defaults to the module constant
        ``EPSILON``.

    Returns
    -------
    float
        The current best approximation of the root. The achieved tolerance
        and iteration count are not reported back to the caller, and when
        the iteration limit is hit the last approximation is returned
        without any failure indication.
    """
    if eps is None:
        eps = EPSILON
    a, b = ax, bx
    c, fc = a, fa
    tol = Tol
    remaining = Maxit + 1

    # A root exactly at an endpoint needs no iteration.
    if fa == 0.0:
        return a
    if fb == 0.0:
        return b

    while remaining:
        remaining -= 1
        # Distance from the last-but-one to the last approximation.
        prev_step = b - a
        if abs(fc) < abs(fb):
            # Swap so that b holds the best approximation so far.
            a, b, c = b, c, b
            fa, fb, fc = fb, fc, fb
        tol_act = 2 * eps * abs(b) + tol / 2  # actual tolerance
        new_step = (c - b) / 2  # bisection step
        if abs(new_step) <= tol_act or fb == 0:
            return b  # acceptable approximation found

        # Interpolation is tried only if the previous step was large enough
        # and went in the right direction.
        if abs(prev_step) >= tol_act and abs(fa) > abs(fb):
            cb = c - b
            if a == c:
                # Only two distinct points: linear (secant) interpolation.
                s = fb / fa
                p = cb * s
                q = 1.0 - s
            else:
                # Inverse quadratic interpolation.
                q = fa / fc
                r = fb / fc
                s = fb / fa
                p = s * (cb * q * (q - r) - (b - a) * (r - 1.0))
                q = (q - 1.0) * (r - 1.0) * (s - 1.0)
            # The step is kept as p/q (division delayed until accepted);
            # make p positive and move the sign into q.
            if p > 0:
                q = -q
            else:
                p = -p
            # Accept b + p/q only if it falls within [b, c] and is not too
            # large; otherwise the bisection step is kept.
            if p < (0.75 * cb * q - abs(tol_act * q) / 2) and p < abs(prev_step * q / 2):
                new_step = p / q

        # Never step by less than the tolerance.
        if abs(new_step) < tol_act:
            new_step = tol_act if new_step > 0 else -tol_act

        a, fa = b, fb  # save the previous approximation
        b += new_step
        fb = f(b, info)
        if (fb > 0 and fc > 0) or (fb < 0 and fc < 0):
            # Keep c on the opposite side of the root from b.
            c, fc = a, fa

    # Iteration limit reached: return the last approximation.
    return b
def _ptruncnorm(q, a, b, mean, sd):
    """CDF at scalar q of N(mean, sd**2) truncated to [a, b].

    Below the support the result is 0.0 and above it 1.0. Inside, it is the
    untruncated normal CDF renormalised over [a, b]:
    (F(q) - F(a)) / (F(b) - F(a)).
    """
    if q < a:
        return 0.0
    if q > b:
        return 1.0

    def cdf(x):
        return norm.cdf(x, loc=mean, scale=sd)

    lo = cdf(a)
    return (cdf(q) - lo) / (cdf(b) - lo)


# Element-wise (broadcasting) version of _ptruncnorm.
ptruncnorm = np.vectorize(_ptruncnorm)
def qtruncnorm(p, a, b, mean, sd):
    """Quantile function of the truncated normal distribution.

    Each argument may be a scalar or array-like. Arguments are recycled: the
    result has the length of the longest argument, and shorter ones are
    cycled (index ``i % len``).

    Parameters
    ----------
    p : float or array_like
        Probabilities.
    a, b : float or array_like
        Lower and upper truncation bounds (may be -inf / inf).
    mean, sd : float or array_like
        Mean and standard deviation of the untruncated normal.

    Returns
    -------
    numpy.ndarray
        1-D array of quantiles: ``a`` where p == 0, ``b`` where p == 1, and
        NaN where p is outside [0, 1] or NaN. Empty if any argument is empty.
    """
    # np.ravel turns scalars, lists and 0-d arrays alike into 1-D arrays.
    ps, avec, bvec, means, sds = (
        np.ravel(np.asarray(v, dtype=float)) for v in (p, a, b, mean, sd))
    lengths = (len(ps), len(avec), len(bvec), len(means), len(sds))
    if min(lengths) == 0:
        return np.zeros(0)
    n = max(lengths)
    ret = np.zeros(n)
    for i in range(n):
        cp = ps[i % len(ps)]
        ca = avec[i % len(avec)]
        cb = bvec[i % len(bvec)]
        cmean = means[i % len(means)]
        csd = sds[i % len(sds)]
        if cp == 0.0:
            ret[i] = ca
        elif cp == 1.0:
            ret[i] = cb
        elif not 0.0 < cp < 1.0:
            # Out of range; the chained comparison is also False for NaN.
            ret[i] = np.nan
        elif ca == -np.inf and cb == np.inf:
            ret[i] = norm(cmean, csd).ppf(cp)
        else:
            # truncnorm_zeroin needs finite bounds, but ca or cb (not both,
            # see above) may be infinite. In that case, step outwards by
            # doubling until the bracket contains the quantile.
            lower, upper = ca, cb
            if lower == -np.inf:
                lower = -1.0
                while ptruncnorm(lower, ca, cb, cmean, csd) - cp >= 0:
                    lower *= 2.0
            elif upper == np.inf:
                upper = 1.0
                while ptruncnorm(upper, ca, cb, cmean, csd) - cp <= 0:
                    upper *= 2.0
            t = {'a': ca, 'b': cb, 'mean': cmean, 'sd': csd, 'p': cp}
            maxit = 200
            tol = 0.0
            ret[i] = truncnorm_zeroin(lower, upper, qtmin(lower, t),
                                      qtmin(upper, t), qtmin, t, tol, maxit)
    return ret
def qtmin(x, t):
    """Return ptruncnorm(x) - p for the truncated normal described by ``t``.

    Helper for qtruncnorm: the root in ``x`` of this function is the
    p-quantile of the truncated normal distribution.

    Parameters
    ----------
    x : float
        Point at which to evaluate.
    t : dict
        Keys 'a' and 'b' (truncation bounds), 'mean', 'sd' and 'p' (target
        probability), as built by qtruncnorm.

    Notes
    -----
    ptruncnorm rises monotonically from 0 at x = a to 1 at x = b. The
    result is therefore -p at x = a and 1 - p at x = b: negative at the
    lower bound and positive at the upper bound whenever 0 < p < 1.
    """
    return ptruncnorm(x, t['a'], t['b'], t['mean'], t['sd']) - t['p']
def rtruncnorm(n, a, b, mean, sd):
    """Draw random samples from the truncated normal distribution.

    Parameters
    ----------
    n : int
        Number of samples. The output length is the larger of ``n`` and the
        longest parameter array; shorter parameter arrays are recycled.
    a, b : float or array_like
        Lower and upper truncation bounds (may be -inf / inf).
    mean, sd : float or array_like
        Mean and standard deviation of the untruncated normal.

    Returns
    -------
    numpy.ndarray
        1-D array of samples. An entry is NaN when its bounds are not
        finite-or-matching-infinite (e.g. a == +inf).

    Raises
    ------
    ValueError
        If ``n`` is None or NaN.
    """
    # NaN never compares equal to anything, so test it with np.isnan.
    if n is None or np.isnan(n):
        raise ValueError("n is NA - aborting.")
    n = int(n)
    avec, bvec, means, sds = (
        np.ravel(np.asarray(v, dtype=float)) for v in (a, b, mean, sd))
    nn = max(n, len(avec), len(bvec), len(means), len(sds))
    # Allocate nn entries (not n): the loop below fills all of them.
    ret = np.zeros(nn)
    for i in range(nn):
        ca = avec[i % len(avec)]
        cb = bvec[i % len(bvec)]
        cmean = means[i % len(means)]
        csd = sds[i % len(sds)]
        if np.isfinite(ca) and np.isfinite(cb):
            ret[i] = r_truncnorm(ca, cb, cmean, csd)
        elif ca == -np.inf and np.isfinite(cb):
            ret[i] = r_righttruncnorm(cb, cmean, csd)
        elif np.isfinite(ca) and cb == np.inf:
            ret[i] = r_lefttruncnorm(ca, cmean, csd)
        elif ca == -np.inf and cb == np.inf:
            ret[i] = rnorm(cmean, csd)
        else:
            ret[i] = np.nan
    return ret


def r_lefttruncnorm(a, mean, sd):
    """Draw one sample from N(mean, sd**2) truncated to [a, inf)."""
    # These one-sided samplers are called by rtruncnorm but were missing
    # from the original translation; they reuse the two-sided samplers
    # with b = inf.
    alpha = (a - mean) / sd
    if alpha < t4:
        # Lower bound near/below the mean: plain normal rejection.
        return mean + sd * nrs_a_b(alpha, np.inf)
    # Lower bound in the tail: exponential rejection (requires alpha > 0).
    return mean + sd * ers_a_b(alpha, np.inf)


def r_righttruncnorm(b, mean, sd):
    """Draw one sample from N(mean, sd**2) truncated to (-inf, b]."""
    beta = (b - mean) / sd
    # By symmetry, -Z with Z ~ N(0, 1) truncated to [-beta, inf) is
    # N(0, 1) truncated to (-inf, beta].
    return mean - sd * r_lefttruncnorm(-beta, 0.0, 1.0)
def r_truncnorm(a, b, mean, sd):
    """Draw one sample from N(mean, sd**2) truncated to the finite [a, b].

    The bounds are standardised to alpha and beta. A rejection sampler is
    then chosen from where [alpha, beta] lies relative to 0 and from the
    thresholds t1, t2 and t3:

    - normal rejection (nrs_a_b);
    - uniform rejection (urs_a_b);
    - half-normal rejection (hnrs_a_b);
    - exponential rejection (ers_a_b).

    An interval entirely below 0 is handled by mirroring it to the positive
    side.

    Returns
    -------
    float
        The sample, or NaN if beta <= alpha (i.e. b <= a for sd > 0).
    """
    alpha = (a - mean) / sd
    beta = (b - mean) / sd
    phi_a = norm.pdf(alpha)
    phi_b = norm.pdf(beta)
    if beta <= alpha:
        return np.nan  # empty interval (NA_REAL in the C original)
    elif alpha <= 0 and 0 <= beta:  # 2: interval contains 0
        if phi_a <= t1 or phi_b <= t1:  # 2 (a)
            return mean + sd * nrs_a_b(alpha, beta)
        else:  # 2 (b)
            return mean + sd * urs_a_b(alpha, beta)
    elif alpha > 0:  # 3: interval entirely above 0
        if phi_a / phi_b <= t2:  # 3 (a)
            return mean + sd * urs_a_b(alpha, beta)
        elif alpha < t3:  # 3 (b)
            return mean + sd * hnrs_a_b(alpha, beta)
        else:  # 3 (c)
            return mean + sd * ers_a_b(alpha, beta)
    else:  # 3s: interval entirely below 0, handled by mirroring
        if phi_b / phi_a <= t2:  # 3s (a)
            return mean - sd * urs_a_b(-beta, -alpha)
        elif beta > -t3:  # 3s (b)
            return mean - sd * hnrs_a_b(-beta, -alpha)
        else:  # 3s (c)
            return mean - sd * ers_a_b(-beta, -alpha)
def nrs_a_b(a, b):
    """Normal rejection sampling: N(0, 1) conditioned on [a, b].

    Standard normal draws are repeated until one lands in [a, b]. The start
    value -DBL_MAX lies below any realistic lower bound, so at least one
    draw is normally made.
    """
    candidate = -DBL_MAX
    while candidate < a or b < candidate:
        candidate = rnorm(0, 1)
    return candidate
def urs_a_b(a, b):
    """Uniform rejection sampling from N(0, 1) truncated to [a, b].

    Proposes x ~ U(a, b) and accepts it with probability phi(x) / ub, where
    ub is the maximum of the standard normal density phi on [a, b].
    """
    # phi is unimodal with its peak at 0. On [a, b] its maximum is therefore
    # phi(0) if 0 lies in the interval, and otherwise at the endpoint closest
    # to 0. (The previous translation turned the C ternary into a boolean,
    # giving ub == 0 for intervals not containing 0, so every uniform
    # proposal was accepted.)
    if a <= 0 <= b:
        ub = norm.pdf(0.0)
    else:
        ub = max(norm.pdf(a), norm.pdf(b))
    # np.random.uniform(a, b) is equivalent to runif(a, b), and
    # np.random.random_sample() to runif(0, 1).
    x = np.random.uniform(a, b)
    while np.random.random_sample() * ub > norm.pdf(x):
        x = np.random.uniform(a, b)
    return x
def hnrs_a_b(a, b):
    """Half-normal rejection sampling: |Z|, Z ~ N(0, 1), conditioned on [a, b].

    The start value a - 1.0 lies outside [a, b], so the loop draws at least
    once.
    """
    draw = a - 1.0
    while draw < a or b < draw:
        draw = abs(rnorm(0, 1))
    return draw
def ers_a_b(a, b):
    """Exponential rejection sampling from N(0, 1) truncated to [a, b].

    Proposes x = a + E, where E is exponential with rate a (scale 1/a). The
    proposal is accepted with probability exp(-(x - a)**2 / 2), provided it
    does not exceed b. Requires a > 0, since 1/a is used as the exponential
    scale. ``b`` may be ``inf``.
    """
    ainv = 1.0 / a
    while True:
        # np.random.exponential takes the scale 1/lambda (same as rexp), and
        # np.random.random_sample() is equivalent to runif(0, 1).
        x = np.random.exponential(ainv) + a
        # (x - a) ** 2 instead of np.pow, which is missing before NumPy 2.0.
        rho = np.exp(-0.5 * (x - a) ** 2)
        if np.random.random_sample() <= rho and x <= b:
            return x
if __name__ == '__main__':
    # Demo: compare this translation against reference values from R's
    # truncnorm package, then plot a histogram of random draws.
    import matplotlib.pyplot as plt

    def seq(start, stop, step):
        """R-style seq(start, stop, by=step): unlike np.arange, includes stop."""
        return np.arange(start, stop + step / 2.0, step)

    def compare(ours, r_values):
        """Print our values and R's values, and whether they agree to 1e-6."""
        print(ours)
        print(r_values)
        print('matches R:', np.allclose(ours, r_values, atol=1e-6))

    print('ptruncnorm')
    y = seq(0, 1, 0.1)
    print(y)
    compare(ptruncnorm(y, 0, 1, 0.5, 0.1).tolist(),
            [0.000000e+00, 3.138461e-05, 1.349612e-03, 2.274986e-02,
             1.586551e-01, 5.000000e-01, 8.413449e-01, 9.772501e-01,
             9.986504e-01, 9.999686e-01, 1.000000e+00])
    y = seq(0, 10, 1)
    print(y)
    compare(ptruncnorm(y, 3, 8, 5, 0.3).tolist(),
            [0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000,
             0.0004290603, 0.5000000000, 0.9995709397, 1.0000000000,
             1.0000000000, 1.0000000000, 1.0000000000])

    print('qtruncnorm')
    x = seq(0, 1, 0.1)
    print(x)
    compare(qtruncnorm(x, 0, 2, 1, 0.2).tolist(),
            [0.0000000, 0.7436899, 0.8316759, 0.8951200, 0.9493306, 1.0000000,
             1.0506694, 1.1048800, 1.1683241, 1.2563101, 2.0000000])

    # Histogram of draws (plain matplotlib instead of the deprecated
    # seaborn.distplot).
    x = rtruncnorm(10000, 1, 5, 2.5, 1)
    plt.hist(x, bins=50, density=True)
    plt.title('rtruncnorm(10000, a=1, b=5, mean=2.5, sd=1)')
    plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment