Skip to content

Instantly share code, notes, and snippets.

@chausies
Last active June 30, 2021 11:24
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save chausies/011df759f167b17b5278264454fff379 to your computer and use it in GitHub Desktop.
Numerically stable and accurate PyTorch implementation of the log of the CDF of the standard normal distribution
# Numerically stable and accurate implementation of the natural logarithm
# of the cumulative distribution function (CDF) for the standard
# Normal/Gaussian distribution in PyTorch.
import matplotlib.pylab as P # replace this with numpy if you want
import torch as T
def norm_cdf(x):
    """Return the standard normal CDF, Phi(x), evaluated elementwise.

    Uses the identity Phi(x) = erfc(-x / sqrt(2)) / 2 instead of the
    algebraically equal (1 + erf(x / sqrt(2))) / 2. The erf form adds two
    nearly cancelling numbers when x is very negative and collapses to
    exactly 0 (e.g. already at x = -10 in float32), whereas erfc of a
    positive argument keeps full relative precision in the lower tail.

    Args:
        x: Floating-point tensor of evaluation points.

    Returns:
        Tensor with the same shape and dtype as ``x``, values in [0, 1].
    """
    return T.erfc(-x / P.sqrt(2)) / 2
def log_norm_cdf_helper(x):
    """Approximate phi(x) / Q(x), the inverse Mills ratio of the standard normal.

    Here phi is the standard normal density and Q(x) = 1 - Phi(x) is the
    upper tail, so Q(x) ~= phi(x) / log_norm_cdf_helper(x). The form used is
    the Boerjesson-Sundberg approximation

        (1 - a) * x + a * sqrt(x**2 + b)

    (a = 0.344, b = 5.334 appear to be one of the constant pairs published by
    Boerjesson & Sundberg, 1979 -- verify against the paper if it matters).
    For x >= 0 its relative error is roughly half a percent or less, and it
    tends to x, the correct asymptote, as x -> +inf.

    Args:
        x: Floating-point tensor; intended for x >= 0.

    Returns:
        Tensor with the same shape as ``x``.
    """
    a = 0.344
    b = 5.334
    # Only x**2 + b sits under the square root. Taking the root of the whole
    # expression would grow like sqrt(a) * x instead of x and badly
    # misestimate the tail.
    return (1 - a) * x + a * (x**2 + b).sqrt()
def log_norm_cdf(x):
    """Return log(Phi(x)), the log of the standard normal CDF, elementwise.

    The naive ``log(Phi(x))`` breaks down in both tails. For x << 0 the CDF
    underflows (or, via ``1 + erf``, cancels) to 0 and the log becomes -inf.
    For x >> 0 the CDF rounds to 1, so the tiny tail mass is lost. This
    routine splits the real line into three regions:

    * x > 0: ``log1p(-Q(x))`` with ``Q(x) = erfc(x / sqrt(2)) / 2``; the
      small tail mass is computed directly and never rounded against 1.
    * -thresh <= x <= 0 (and NaN): ``log(erfc(-x / sqrt(2)) / 2)``; erfc is
      evaluated at a nonnegative argument, so there is no cancellation.
    * x < -thresh: the asymptotic expansion
      ``-x**2/2 - log(-x) - log(2*pi)/2
        + log1p(sum_{k>=1} (-1)**k * (2k-1)!! / x**(2k))``,
      truncated after ``n_terms`` terms; it stays finite long after erfc
      would underflow.

    Results are written through boolean masks rather than ``T.where``.
    Autograd therefore never differentiates a branch outside its valid range,
    which would otherwise inject inf/NaN gradients.

    Args:
        x: Tensor (or anything ``T.as_tensor`` accepts) of evaluation
            points. Non-floating inputs are cast to the default float dtype.

    Returns:
        Floating tensor with the same shape as ``x``; NaN inputs give NaN.
    """
    x = T.as_tensor(x)
    if not x.is_floating_point():
        x = x.to(T.get_default_dtype())

    # Switch point for the asymptotic series. For |x| > 12 the first omitted
    # series term is below ~2.5e-14 (relative). Meanwhile erfc(12 / sqrt(2))
    # (~3.6e-33) is still a normal float32 number, so both sides of the
    # switch are accurate in float32 and float64.
    thresh = 12.0
    n_terms = 10

    out = T.empty_like(x)
    upper = x > 0
    lower = x < -thresh
    middle = ~(upper | lower)  # also catches NaN, which then propagates

    out[upper] = T.log1p(-T.erfc(x[upper] / P.sqrt(2)) / 2)
    out[middle] = T.log(T.erfc(-x[middle] / P.sqrt(2)) / 2)

    t = -x[lower]  # strictly greater than thresh
    inv_t2 = 1 / (t * t)
    term = -inv_t2  # k = 1 term of the series: -1 / t**2
    series = T.zeros_like(t)
    for k in range(1, n_terms + 1):
        series = series + term
        # term_{k+1} = -term_k * (2k + 1) / t**2
        term = -term * (2 * k + 1) * inv_t2
    out[lower] = -(t * t + P.log(2 * P.pi)) / 2 - T.log(t) + T.log1p(series)
    return out
# Demo: plot the stable log-CDF next to the naive log(erf-based CDF).
if __name__ == "__main__":
    grid = T.linspace(-10, 10, 25)
    stable = log_norm_cdf(grid)
    # Naive route: build the CDF from erf, then take its log.
    naive = ((1 + T.erf(grid / P.sqrt(2))) / 2).log()
    curves = ((stable, "improved"), (naive, "original"))
    for values, name in curves:
        P.plot(grid, values, label=name)
    P.legend()
    P.show()
@chausies
Copy link
Author

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.