Skip to content

Instantly share code, notes, and snippets.

@maxsei
Created July 28, 2021 14:58
Show Gist options
  • Save maxsei/07b02f65d4cca54d7f8d1bf83b5ae6d4 to your computer and use it in GitHub Desktop.
Save maxsei/07b02f65d4cca54d7f8d1bf83b5ae6d4 to your computer and use it in GitHub Desktop.
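Shared Gaussian Integral. This calculates the area shared by two Gaussian density curves (the overlapping coefficient, i.e. the integral of min(p1, p2) over the real line), not the KL divergence.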
def shared_gaussian_integral(u1, u2, s1, s2, ep=1e-8, max_iter=100):
    """Return the area shared by two normal probability density curves.

    Computes the integral over the whole real line of ``min(p1(x), p2(x))``,
    where ``p1`` is the density of N(u1, s1**2) and ``p2`` the density of
    N(u2, s2**2).  This is the overlapping coefficient of the two
    distributions: 1.0 when they are identical, approaching 0.0 as they
    separate.

    Args:
        u1: Mean of the first distribution.
        u2: Mean of the second distribution.
        s1: Standard deviation of the first distribution (must be > 0).
        s2: Standard deviation of the second distribution (must be > 0).
        ep: Newton's method stops once a step is smaller than this absolute
            tolerance.  Only used when ``s1 != s2``.
        max_iter: Upper bound on Newton iterations.  Only used when
            ``s1 != s2``.

    Returns:
        The shared area, a float in [0, 1].

    Raises:
        ValueError: If ``s1`` or ``s2`` is not strictly positive.
    """
    import math

    # Written as a negated conjunction so NaN sigmas are rejected too.
    if not (s1 > 0 and s2 > 0):
        raise ValueError("standard deviations s1 and s2 must be positive")

    root2 = math.sqrt(2)

    def cdf(x, u, s):
        # P(X <= x) for X ~ N(u, s**2); erfc keeps precision in the far tails.
        return math.erfc((u - x) / (s * root2)) / 2

    def sf(x, u, s):
        # P(X > x) for X ~ N(u, s**2), computed directly instead of 1 - cdf
        # to avoid cancellation in the upper tail.
        return math.erfc((x - u) / (s * root2)) / 2

    if s1 == s2:
        # Equal widths: the densities cross once, at the midpoint of the
        # means, and the shared area is the two (equal) tails beyond it:
        # 2 * Phi(-|u1 - u2| / (2 * s)) == erfc(|u1 - u2| / (2 * s * sqrt(2))).
        # Identical distributions give erfc(0) == 1.0, i.e. complete overlap.
        return math.erfc(abs(u1 - u2) / (2 * s1 * root2))

    var1, var2 = s1 * s1, s2 * s2
    log_ratio = math.log(s2 / s1)

    def f(x):
        # 2 * s1**2 * s2**2 * (log p1(x) - log p2(x)): a parabola in x whose
        # zeros are exactly the points where the two densities cross.
        return (
            var1 * (x - u2) ** 2
            - var2 * (x - u1) ** 2
            + 2 * var1 * var2 * log_ratio
        )

    def df(x):
        return 2 * (var1 * (x - u2) - var2 * (x - u1))

    # Vertex of the parabola f (where df == 0); its two zeros are mirror
    # images of each other about this point.
    vertex = (u2 * var1 - u1 * var2) / (var1 - var2)

    # Start halfway between the vertex and the "mean +/- sigma" point that
    # lies farthest from it (each mean paired with its own sigma).  That
    # start is never the vertex itself, and Newton's method on a parabola
    # converges to the zero on the same side of the vertex as its start.
    furthest = max(
        (u1 + s1, u1 - s1, u2 + s2, u2 - s2), key=lambda p: abs(vertex - p)
    )
    x = (vertex + furthest) / 2
    for _ in range(max_iter):
        step = f(x) / df(x)
        x -= step
        if abs(step) < ep:
            break
    # The second crossing is the reflection of the first about the vertex.
    r0, r1 = sorted((x, 2 * vertex - x))

    # The narrower density is the larger one between the crossings and the
    # smaller one outside them, so min(p1, p2) is the narrow density on both
    # tails and the wide density in the middle.
    if s1 < s2:
        (un, sn), (uw, sw) = (u1, s1), (u2, s2)
    else:
        (un, sn), (uw, sw) = (u2, s2), (u1, s1)
    return (
        cdf(r0, un, sn)
        + (cdf(r1, uw, sw) - cdf(r0, uw, sw))
        + sf(r1, un, sn)
    )
if __name__ == "__main__":
    # Example: shared area of N(0, 1) and N(0, 0.5**2).  The original call
    # used the undefined name ``area_overlap`` and would raise NameError.
    print(shared_gaussian_integral(0, 0, 1, 0.5))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment