Skip to content

Instantly share code, notes, and snippets.

@WarrenWeckesser
Created March 26, 2022 11:40
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save WarrenWeckesser/636b537ee889679227d53543d333a720 to your computer and use it in GitHub Desktop.
Save WarrenWeckesser/636b537ee889679227d53543d333a720 to your computer and use it in GitHub Desktop.
Compute stats for the truncnorm distribution using mpmath
import mpmath
# Work with 80 significant decimal digits.  The normalizing mass computed
# for far-tail intervals is a difference of two nearly equal CDF values;
# if it comes out as exactly 0, truncnorm_pdf() raises and asks for a
# larger value here.
mpmath.mp.dps = 80
def truncnorm_delta_cdf(a, b):
    """
    Return the standard normal probability mass of the interval [a, b].

    Mathematically this is ``ncdf(b) - ncdf(a)``.  When ``a > 0`` the
    equivalent mirrored form ``ncdf(-a) - ncdf(-b)`` is evaluated
    instead (presumably so that the subtraction involves small tail
    values rather than two values close to 1).
    """
    if a > 0:
        # Interval lies in the right tail: use the reflected left tail.
        return mpmath.ncdf(-a) - mpmath.ncdf(-b)
    return mpmath.ncdf(b) - mpmath.ncdf(a)
def truncnorm_pdf(x, a, b):
    """
    Evaluate the PDF of the standard normal distribution truncated to
    the interval [a, b] at the point x.

    Inside the interval the density is ``npdf(x)`` divided by the
    normal probability mass of [a, b]; outside the interval it is 0.

    Raises
    ------
    ValueError
        If ``a >= b``.
    RuntimeError
        If the probability mass of [a, b] evaluates to exactly 0 at the
        current mpmath working precision.
    """
    if a >= b:
        raise ValueError("'a' must be less than 'b'")
    delta_cdf = truncnorm_delta_cdf(a, b)
    if delta_cdf == 0:
        raise RuntimeError("delta_cdf is 0; try increasing mpmath.mp.dps.")
    # The truncated distribution has no mass outside its support.  The
    # check comes after the validation above so that invalid (a, b)
    # still raise regardless of x.  A nan x fails both comparisons and
    # falls through, so it still propagates as nan.
    if x < a or x > b:
        return mpmath.mpf(0)
    return mpmath.npdf(x) / delta_cdf
def truncnorm_stats(a, b):
    """
    Compute statistics of the standard normal distribution truncated
    to the interval [a, b], using closed-form expressions.

    Either endpoint may be infinite (``a = -inf`` and/or ``b = inf``).

    Returns
    -------
    tuple
        ``(mean, variance, skewness, excess kurtosis)`` as mpmath
        values.

    Raises
    ------
    ValueError, RuntimeError
        Propagated from truncnorm_pdf() when ``a >= b`` or when the
        probability mass of the interval evaluates to 0.
    """
    lower = mpmath.mpf(a)
    upper = mpmath.mpf(b)
    pdf_lower = truncnorm_pdf(lower, lower, upper)
    pdf_upper = truncnorm_pdf(upper, lower, upper)

    # The formulas below multiply the endpoint density by powers of the
    # endpoint.  For an infinite endpoint that product would be 0*inf,
    # i.e. nan, so the endpoint is replaced by 0 to make the term vanish.
    if upper == mpmath.inf:
        upper = 0
    if lower == -mpmath.inf:
        lower = 0

    # Noncentral moments (moments about 0).
    raw1 = pdf_lower - pdf_upper
    raw2 = 1 + pdf_lower*lower - pdf_upper*upper
    raw3 = 2*raw1 + pdf_lower*lower**2 - pdf_upper*upper**2
    raw4 = 3*raw2 + pdf_lower*lower**3 - pdf_upper*upper**3

    # Central moments (moments about the mean, which is raw1).
    variance = (lower - raw1)*pdf_lower - (upper - raw1)*pdf_upper + 1
    central3 = raw3 + raw1 * (-3*raw2 + 2*raw1**2)
    central4 = raw4 + raw1*(-4*raw3 + 3*raw1*(2*raw2 - raw1**2))

    skewness = central3 / mpmath.power(variance, 1.5)
    excess_kurtosis = central4 / variance**2 - 3
    return raw1, variance, skewness, excess_kurtosis
def truncnorm_stats_quad(a, b):
    """
    Compute statistics of the standard normal distribution truncated
    to the finite interval [a, b], obtaining the third and fourth
    central moments by numerical integration with mpmath.quad.

    This is slow, but it provides a check for mistakes in the
    implementation of truncnorm_stats(a, b).  The mean and variance
    use the same closed-form expressions as truncnorm_stats().

    Returns
    -------
    tuple
        ``(mean, variance, skewness, excess kurtosis)`` as mpmath
        values.

    Raises
    ------
    ValueError
        If ``a`` or ``b`` is infinite, or (via truncnorm_pdf) if
        ``a >= b``.
    RuntimeError
        Via truncnorm_pdf, if the probability mass of [a, b] evaluates
        to 0 at the current working precision.
    """
    if mpmath.isinf(a) or mpmath.isinf(b):
        raise ValueError("truncnorm_stats_quad requires both 'a' and 'b' "
                         "to be finite.")
    a = mpmath.mpf(a)
    b = mpmath.mpf(b)
    # These calls also validate a < b and a nonzero normalizer.
    pa = truncnorm_pdf(a, a, b)
    pb = truncnorm_pdf(b, a, b)

    # The normalizer does not depend on the integration variable, so
    # compute it once here instead of re-evaluating two high-precision
    # CDF values at every quadrature node.
    delta_cdf = truncnorm_delta_cdf(a, b)

    def pdf(t):
        # Truncated-normal density at t, for t within [a, b].
        return mpmath.npdf(t) / delta_cdf

    # mu is the mean.
    # mu# are moments about the mean (i.e. central moments).
    mu = pa - pb
    mu2 = (a - mu)*pa - (b - mu)*pb + 1
    mu3 = mpmath.quad(lambda t: pdf(t)*(t - mu)**3, [a, b])
    g1 = mu3 / mpmath.power(mu2, 1.5)
    mu4 = mpmath.quad(lambda t: pdf(t)*(t - mu)**4, [a, b])
    g2 = mu4 / mu2**2 - 3
    return mu, mu2, g1, g2
def print_table(intervals):
    """
    Print a right-aligned table of truncnorm statistics.

    `intervals` is an iterable of ``(a, b)`` pairs.  One header row is
    printed, followed by one row per interval containing a, b and the
    four values returned by truncnorm_stats(a, b) converted to float.
    """
    titles = ('a', 'b', 'mean', 'variance', 'skewness', 'excess kurtosis')
    widths = (6, 6, 24, 24, 24, 24)
    print(" ".join(f"{title:>{width}s}"
                   for title, width in zip(titles, widths)))

    for lower, upper in intervals:
        # The endpoints are written before the statistics are computed.
        print(f"{str(lower):>6} {str(upper):>6}", end='')
        cells = [f" {repr(float(stat)):>24s}"
                 for stat in truncnorm_stats(lower, upper)]
        print("".join(cells))
def parstr(value):
    """
    Return `value` formatted as text for the generated test code.

    Infinite values become ``np.inf`` or ``-np.inf``; anything else is
    converted with ``str()``.
    """
    if not mpmath.isinf(value):
        return str(value)
    return "np.inf" if value >= 0 else "-np.inf"
def print_test_code(intervals):
    """
    Print Python source defining ``_truncnorm_stats_data``.

    `intervals` is an iterable of ``(a, b)`` pairs.  Each pair produces
    one list entry holding a, b and the four float values obtained from
    truncnorm_stats(a, b).  The values go on a single line when that
    line is at most 79 characters long, otherwise one value per line.
    """
    preamble = (
        "# Test data for the truncnorm stats() method.",
        "# The data in each row is:",
        "# a, b, mean, variance, skewness, excess kurtosis.",
        "_truncnorm_stats_data = [",
    )
    for text in preamble:
        print(text)

    pad = " "
    for lower, upper in intervals:
        print(f" [{parstr(lower)}, {parstr(upper)},")
        values = [str(float(stat)) for stat in truncnorm_stats(lower, upper)]
        separator = ', '
        if len(pad + separator.join(values) + '],') > 79:
            # Too wide for one line: put each value on its own line.
            separator = ',\n' + pad
        print(pad + separator.join(values) + '],')
    print("]")
if __name__ == "__main__":
    # Intervals (a, b) for which the statistics are tabulated.
    check_intervals = [
        # Symmetric intervals.
        (-30, 30),
        (-10, 10),
        (-3, 3),
        (-2, 2),
        # Half-infinite intervals.
        (0, mpmath.inf),
        (-mpmath.inf, 0),
        # Asymmetric intervals around 0.
        (-1, 3),
        (-3, 1),
        # Narrow intervals in the tails.
        (-10, -9),
        (-20, -19),
        (-30, -29),
        (-40, -39),
        (39, 40),
    ]

    # To emit the data as Python test code instead of a table, call
    # print_test_code(check_intervals) here.
    print_table(check_intervals)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment