Skip to content

Instantly share code, notes, and snippets.

@yoavram
Last active June 29, 2017 10:40
Show Gist options
  • Save yoavram/c6a4faa9abdaeb2daeba to your computer and use it in GitHub Desktop.
Save yoavram/c6a4faa9abdaeb2daeba to your computer and use it in GitHub Desktop.
Truncated normal distribution in Python. A translation of the CRAN/truncnorm (https://github.com/cran/truncnorm) C code to Python. The examples under __main__ are comparisons to R.
import numpy as np
from scipy.stats import norm
# Relative precision term used in the root finder's per-step tolerance
# (see truncnorm_zeroin).
EPSILON = 1e-10
# Large magnitude whose negative serves as the "not yet sampled" start value
# in nrs_a_b.
DBL_MAX = 1.79769e+308
# 1 / sqrt(2 * pi); used in urs_a_b as an upper bound of the normal density.
M_1_SQRT_2PI = 1.0/np.sqrt(2.0 * np.pi)
# Thresholds compared against densities / bounds in r_truncnorm to choose
# which rejection sampler to run.
t1 = 0.15
t2 = 2.18
t3 = 0.725
# NOTE(review): t4 is not used by r_truncnorm; presumably it is the threshold
# for the one-sided (half-infinite) samplers of the C original -- confirm.
t4 = 0.45
# Random number helpers named after their R counterparts
def runif(a, b):
    """Return one uniform random draw scaled from the unit interval to a..b."""
    unit_draw = np.random.random_sample()
    return a + unit_draw * (b - a)


# Plain aliases of the NumPy samplers under R's names.
rnorm = np.random.normal
rexp = np.random.exponential
def truncnorm_zeroin(ax, bx, fa, fb, f, info, Tol, Maxit):
    """Estimate a root of ``f(x, info)`` starting from the range ``ax..bx``.

    Every iteration first prepares a bisection step and replaces it with a
    linear or inverse quadratic interpolation step when that step is
    acceptable.

    Parameters
    ----------
    ax, bx
        Left and right border of the range in which the root is sought.
    fa, fb
        ``f(ax, info)`` and ``f(bx, info)``, pre-computed by the caller.
    f
        Function under investigation, called as ``f(x, info)``.
    info
        Additional information passed on to ``f`` unchanged.
    Tol
        Acceptable tolerance.
    Maxit
        Maximum number of iterations.

    Returns
    -------
    The estimate of the root.  When the iteration budget is exhausted the
    latest approximation is returned; no failure indicator is reported.
    """
    a, b = ax, bx
    c, fc = a, fa
    remaining = Maxit + 1
    tol = Tol

    # A root located exactly at an end point needs no iteration.
    if fa == 0.0:
        return a
    if fb == 0.0:
        return b

    while remaining:
        remaining -= 1
        # Distance from the last but one to the last approximation.
        prev_step = b - a

        if abs(fc) < abs(fb):
            # Swap so that b holds the best approximation.
            a, b, c = b, c, b
            fa, fb, fc = fb, fc, fb

        tol_act = 2 * EPSILON * abs(b) + tol / 2  # actual tolerance
        new_step = (c - b) / 2  # bisection step for this iteration

        if abs(new_step) <= tol_act or fb == 0:
            return b  # acceptable approximation found

        # Interpolation is tried only if the previous step was large enough
        # and moved in the right direction.
        if abs(prev_step) >= tol_act and abs(fa) > abs(fb):
            cb = c - b
            if a == c:
                # Only two distinct points: linear interpolation.
                ratio = fb / fa
                p = cb * ratio
                q = 1.0 - ratio
            else:
                # Inverse quadratic interpolation.
                q = fa / fc
                r1 = fb / fc
                r2 = fb / fa
                p = r2 * (cb * q * (q - r1) - (b - a) * (r1 - 1.0))
                q = (q - 1.0) * (r1 - 1.0) * (r2 - 1.0)

            # The step is kept as p/q (division delayed); p was computed
            # with the opposite sign, so make p positive and move a
            # possible minus sign to q.
            if p > 0:
                q = -q
            else:
                p = -p

            # Accept the interpolated step only if b + p/q falls in [b, c]
            # and is not too large; otherwise keep the bisection step.
            fits_bracket = p < (0.75 * cb * q - abs(tol_act * q) / 2)
            if fits_bracket and p < abs(prev_step * q / 2):
                new_step = p / q

        # Never move by less than the tolerance.
        if abs(new_step) < tol_act:
            new_step = tol_act if new_step > 0 else -tol_act

        # Save the previous approximation and step to the new one.
        a, fa = b, fb
        b += new_step
        fb = f(b, info)

        # Keep c on the side where f has the sign opposite to f(b).
        if (fb > 0 and fc > 0) or (fb < 0 and fc < 0):
            c, fc = a, fa

    # Iteration budget exhausted.
    return b
def _ptruncnorm(q, a, b, mean, sd):
    """Scalar CDF at ``q`` of a normal(mean, sd) truncated to ``[a, b]``.

    Returns 0.0 below ``a`` and 1.0 above ``b``; in between, the parent CDF
    is rescaled by the probability mass the parent puts on ``[a, b]``.
    """
    if q < a:
        return 0.0
    if q > b:
        return 1.0
    parent = norm(mean, sd)
    cdf_lower = parent.cdf(a)
    cdf_upper = parent.cdf(b)
    return (parent.cdf(q) - cdf_lower) / (cdf_upper - cdf_lower)


# Element-wise version of _ptruncnorm built with np.vectorize.
ptruncnorm = np.vectorize(_ptruncnorm)
def qtruncnorm(p, a, b, mean, sd):
    """Quantile function of the truncated normal distribution.

    Every argument may be a scalar or a one-dimensional array-like; shorter
    arguments are recycled (R style) up to the length of the longest one.

    Parameters
    ----------
    p
        Probabilities at which to evaluate the quantile function.
    a, b
        Lower and upper truncation bound (either may be infinite).
    mean, sd
        Mean and standard deviation of the parent normal distribution.

    Returns
    -------
    numpy.ndarray
        One quantile per recycled element.  ``p == 0`` yields ``a``,
        ``p == 1`` yields ``b`` and a probability outside ``[0, 1]`` (or
        NaN) yields NaN.  An empty argument yields an empty result.
    """
    # Treat scalars and sequences uniformly; previously a scalar ``p`` was
    # indexed directly and raised TypeError.
    p = np.atleast_1d(p)
    a = np.atleast_1d(a)
    b = np.atleast_1d(b)
    mean = np.atleast_1d(mean)
    sd = np.atleast_1d(sd)
    sizes = [len(p), len(a), len(b), len(mean), len(sd)]
    if min(sizes) == 0:
        return np.zeros(0)
    n = max(sizes)
    ret = np.zeros(n)
    for i in range(n):
        cp = p[i % len(p)]
        ca = a[i % len(a)]
        cb = b[i % len(b)]
        cmean = mean[i % len(mean)]
        csd = sd[i % len(sd)]
        if cp == 0.0:
            ret[i] = ca
        elif cp == 1.0:
            ret[i] = cb
        elif not (0.0 < cp < 1.0):
            # Out-of-range probabilities and NaN (which fails every
            # comparison) have no quantile.
            ret[i] = np.nan
        elif ca == -np.inf and cb == np.inf:
            # No truncation at all: plain normal quantile.
            ret[i] = norm(cmean, csd).ppf(cp)
        else:
            # The root finder requires finite bounds, but ca or cb (not
            # both, see above) may be infinite.  In that case step outwards
            # by doubling until the truncated CDF crosses cp.
            lower = ca
            upper = cb
            if lower == -np.inf:
                lower = -1
                while ptruncnorm(lower, ca, cb, cmean, csd) - cp >= 0:
                    lower *= 2.0
            elif upper == np.inf:
                upper = 1
                while ptruncnorm(upper, ca, cb, cmean, csd) - cp <= 0:
                    upper *= 2.0
            t = {'a': ca, 'b': cb, 'mean': cmean, 'sd': csd, 'p': cp}
            maxit = 200
            tol = 0.0
            ret[i] = truncnorm_zeroin(
                lower, upper, qtmin(lower, t), qtmin(upper, t),
                qtmin, t, tol, maxit)
    return ret
def qtmin(x, t):
    """Objective whose root in ``x`` is a truncated normal quantile.

    ``t`` is a dict with the keys ``'a'``, ``'b'``, ``'mean'`` and ``'sd'``
    (passed to ``ptruncnorm``) and ``'p'`` (the target probability).  The
    return value is ``ptruncnorm(x, ...) - p``, i.e. ``-p`` at ``x = a`` and
    ``1 - p`` at ``x = b``.
    """
    cdf_at_x = ptruncnorm(x, *(t[key] for key in ('a', 'b', 'mean', 'sd')))
    return cdf_at_x - t['p']
def r_lefttruncnorm(a, mean, sd):
    """Draw one sample from normal(mean, sd) truncated to ``[a, inf)``.

    Follows the one-sided sampler of the CRAN truncnorm C code: normal
    rejection sampling while the standardised bound is below ``t4``,
    exponential rejection sampling otherwise.
    """
    alpha = (a - mean) / sd
    if alpha < t4:
        # Bound is not far in the right tail: redraw until above it.
        z = rnorm(0, 1)
        while z < alpha:
            z = rnorm(0, 1)
    else:
        # alpha >= t4 > 0 here, so the division is safe.  rexp takes the
        # scale (1 / rate), i.e. the proposal is alpha + Exp(rate=alpha).
        scale = 1.0 / alpha
        while True:
            z = rexp(scale) + alpha
            if runif(0, 1) <= np.exp(-0.5 * (z - alpha) ** 2):
                break
    return mean + sd * z


def r_righttruncnorm(b, mean, sd):
    """Draw one sample from normal(mean, sd) truncated to ``(-inf, b]``."""
    beta = (b - mean) / sd
    # Exploit symmetry: mirror a left-truncated standard normal draw.
    return mean - sd * r_lefttruncnorm(-beta, 0.0, 1.0)


def rtruncnorm(n, a, b, mean, sd):
    """Draw random samples from truncated normal distributions.

    ``a``, ``b``, ``mean`` and ``sd`` may be scalars or one-dimensional
    array-likes; shorter ones are recycled (R style).

    Parameters
    ----------
    n
        Requested number of samples.
    a, b
        Lower and upper truncation bound (either may be infinite).
    mean, sd
        Mean and standard deviation of the parent normal distribution.

    Returns
    -------
    numpy.ndarray
        ``max(n, longest parameter length)`` samples, or an empty array when
        ``n`` is 0 or a parameter is empty.  An element is NaN when its
        bounds match none of the supported finite/infinite combinations
        (for example a NaN bound).

    Raises
    ------
    ValueError
        If ``n`` is NaN.
    """
    # ``np.nan == n`` is always False, so test with isnan.
    if np.isnan(n):
        raise ValueError("n is NA - aborting.")
    n = int(n)
    a = np.atleast_1d(a)
    b = np.atleast_1d(b)
    mean = np.atleast_1d(mean)
    sd = np.atleast_1d(sd)
    sizes = [len(a), len(b), len(mean), len(sd)]
    if n == 0 or min(sizes) == 0:
        return np.zeros(0)
    nn = max(n, max(sizes))
    # The loop below fills nn entries, so the result must have nn slots
    # (it used to have n and overflowed for longer parameter vectors).
    ret = np.zeros(nn)
    for i in range(nn):
        ca = a[i % len(a)]
        cb = b[i % len(b)]
        cmean = mean[i % len(mean)]
        csd = sd[i % len(sd)]
        if np.isfinite(ca) and np.isfinite(cb):
            ret[i] = r_truncnorm(ca, cb, cmean, csd)
        elif -np.inf == ca and np.isfinite(cb):
            ret[i] = r_righttruncnorm(cb, cmean, csd)
        elif np.isfinite(ca) and np.inf == cb:
            ret[i] = r_lefttruncnorm(ca, cmean, csd)
        elif -np.inf == ca and np.inf == cb:
            ret[i] = rnorm(cmean, csd)
        else:
            ret[i] = np.nan
    return ret
def r_truncnorm(a, b, mean, sd):
    """Draw one sample from normal(mean, sd) truncated to finite ``[a, b]``.

    The bounds are standardised to ``alpha``/``beta`` and one of four
    rejection samplers is chosen from their position relative to zero and
    the thresholds ``t1``, ``t2`` and ``t3``.

    Returns NaN when the interval is empty or degenerate (``b <= a`` after
    standardising).
    """
    alpha = (a - mean) / sd
    beta = (b - mean) / sd
    if beta <= alpha:
        # The C original returned NA_REAL, a name that does not exist in
        # Python; NaN is the equivalent.
        return np.nan

    phi_a = norm.pdf(alpha)
    phi_b = norm.pdf(beta)

    if alpha <= 0 and 0 <= beta:  # case 2: interval contains the mode
        if phi_a <= t1 or phi_b <= t1:  # 2 (a)
            return mean + sd * nrs_a_b(alpha, beta)
        return mean + sd * urs_a_b(alpha, beta)  # 2 (b)

    if alpha > 0:  # case 3: interval right of the mode
        if phi_a / phi_b <= t2:  # 3 (a)
            return mean + sd * urs_a_b(alpha, beta)
        if alpha < t3:  # 3 (b)
            return mean + sd * hnrs_a_b(alpha, beta)
        return mean + sd * ers_a_b(alpha, beta)  # 3 (c)

    # Case 3s: interval left of the mode; sample the mirrored interval
    # [-beta, -alpha] and flip the sign.
    if phi_b / phi_a <= t2:  # 3s (a)
        return mean - sd * urs_a_b(-beta, -alpha)
    if beta > -t3:  # 3s (b)
        return mean - sd * hnrs_a_b(-beta, -alpha)
    return mean - sd * ers_a_b(-beta, -alpha)  # 3s (c)
def nrs_a_b(a, b):
    """Normal rejection sampling on (a, b).

    Draws standard normal variates until one is neither below ``a`` nor
    above ``b`` and returns it.
    """
    candidate = -DBL_MAX  # start value, rejected unless it lies in range
    while True:
        if not (candidate < a or candidate > b):
            return candidate
        candidate = rnorm(0, 1)
def urs_a_b(a, b):
    """Uniform rejection sampling of a standard normal restricted to [a, b].

    Proposes ``x`` uniformly on ``[a, b]`` and accepts it with probability
    ``pdf(x) / ub``, where ``ub`` is the maximum of the standard normal
    density on the interval.
    """
    rv = norm(0, 1)
    # Upper bound of the normal density on [a, b]: the density peaks at 0,
    # so the maximum sits at 0 when the interval contains it and at the
    # bound nearest to 0 otherwise.  (The earlier expression
    # ``a < 0 and b > 0 if M_1_SQRT_2PI else phi_a`` always evaluated to a
    # bool, which made the sampler accept every proposal whenever it was
    # False.)
    if a <= 0 <= b:
        ub = M_1_SQRT_2PI
    elif a > 0:
        ub = rv.pdf(a)
    else:
        ub = rv.pdf(b)
    x = runif(a, b)
    while runif(0, 1) * ub > rv.pdf(x):
        x = runif(a, b)
    return x
def hnrs_a_b(a, b):
    """Half-normal rejection sampling on (a, b).

    Draws absolute values of standard normal variates until one is neither
    below ``a`` nor above ``b`` and returns it.
    """
    candidate = a - 1.0  # start value just below the lower bound
    while True:
        if not (candidate < a or candidate > b):
            return candidate
        candidate = abs(rnorm(0, 1))
def ers_a_b(a, b):
    """Exponential rejection sampling on (a, b).

    Proposes ``x = a + E`` with ``E`` exponential of scale ``1 / a`` and
    accepts with probability ``exp(-0.5 * (x - a) ** 2)``, additionally
    rejecting proposals above ``b``.  ``a`` must be non-zero because of the
    division below.
    """
    scale = 1.0 / a  # rexp works with the scale, i.e. 1 / lambda
    while True:
        x = rexp(scale) + a
        # ``**`` replaces np.pow, which is missing from NumPy before 2.0.
        rho = np.exp(-0.5 * (x - a) ** 2)
        if not (runif(0, 1) > rho or x > b):
            return x
if __name__ == '__main__':
    # Demo: print our results next to reference values produced by R.

    def seq(start, stop, step):
        """R-like seq(): evenly spaced values that INCLUDE the end point.

        np.arange drops the end point, which left the printed vectors one
        element short of the R reference lists below.
        """
        count = int(round((stop - start) / step)) + 1
        return np.linspace(start, stop, count)

    print('ptruncnorm')
    y = seq(0, 1, 0.1)
    print(y)
    print(ptruncnorm(y, 0, 1, 0.5, 0.1).tolist())
    # Reference values from R
    print([0.000000e+00,3.138461e-05, 1.349612e-03, 2.274986e-02, 1.586551e-01, 5.000000e-01, 8.413449e-01, 9.772501e-01, 9.986504e-01, 9.999686e-01, 1.000000e+00])
    y = seq(0, 10, 1)
    print(y)
    print(ptruncnorm(y, 3, 8, 5, 0.3).tolist())
    # Reference values from R
    print([0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0004290603, 0.5000000000, 0.9995709397, 1.0000000000, 1.0000000000, 1.0000000000, 1.0000000000])

    print('qtruncnorm')
    x = seq(0, 1, 0.1)
    print(x)
    print(qtruncnorm(x, 0, 2, 1, 0.2).tolist())
    # Reference values from R
    print([0.0000000, 0.7436899, 0.8316759, 0.8951200, 0.9493306, 1.0000000, 1.0506694, 1.1048800, 1.1683241, 1.2563101, 2.0000000])

    # Visual check: histogram of 10000 draws truncated to [1, 5].
    x = rtruncnorm(10000, 1, 5, 2.5, 1)
    import seaborn as sns
    import matplotlib.pyplot as plt
    # NOTE(review): distplot is deprecated in recent seaborn releases;
    # confirm the installed version still provides it.
    sns.distplot(x)
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment