@davidlenz
Last active December 5, 2020 07:05
Implementation of the Jensen-Shannon divergence, based on https://github.com/scipy/scipy/issues/8244
import numpy as np
from scipy.stats import entropy

def js(p, q):
    """Jensen-Shannon divergence between two discrete distributions (in nats)."""
    # cast to float so the in-place normalization below also works for integer input
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    # normalize to proper probability distributions
    p /= p.sum()
    q /= q.sum()
    # mixture distribution
    m = (p + q) / 2
    # JSD(p, q) = (KL(p || m) + KL(q || m)) / 2
    return (entropy(p, m) + entropy(q, m)) / 2
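
A quick usage sketch (the count vectors below are made up for illustration). scipy.stats.entropy uses the natural logarithm by default, so js returns the divergence in nats; scipy.spatial.distance.jensenshannon (SciPy >= 1.2) returns the Jensen-Shannon distance, i.e. the square root of this divergence, and can serve as a cross-check:

from scipy.spatial.distance import jensenshannon

p = [1, 2, 7]   # unnormalized counts are fine, js() normalizes them
q = [4, 4, 2]
print(js(p, q))                  # divergence in nats
print(jensenshannon(p, q) ** 2)  # squared JS distance from SciPy, should agree closely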
@Darthholi

Just for those who land here looking for the Jensen-Shannon distance between two continuous distributions (estimated with Monte Carlo integration):

import numpy as np
import scipy.stats as st

def distributions_js(distribution_p, distribution_q, n_samples=10 ** 5):
    # Jensen-Shannon divergence (the Jensen-Shannon distance is the square root of the divergence).
    # All logarithms are base 2, so the result is measured in bits (information entropy).
    # KL(p || m) is estimated with samples drawn from p, KL(q || m) with samples drawn from q,
    # where m = (p + q) / 2 is the mixture; note log2 m(x) = log2(p(x) + q(x)) - log2(2).
    X = distribution_p.rvs(n_samples)
    p_X = distribution_p.pdf(X)
    q_X = distribution_q.pdf(X)
    log_mix_X = np.log2(p_X + q_X)

    Y = distribution_q.rvs(n_samples)
    p_Y = distribution_p.pdf(Y)
    q_Y = distribution_q.pdf(Y)
    log_mix_Y = np.log2(p_Y + q_Y)

    return (np.log2(p_X).mean() - (log_mix_X.mean() - np.log2(2))
            + np.log2(q_Y).mean() - (log_mix_Y.mean() - np.log2(2))) / 2

print("should be different:")
print(distributions_js(st.norm(loc=10000), st.norm(loc=0)))
print("should be same:")
print(distributions_js(st.norm(loc=0), st.norm(loc=0)))

https://stats.stackexchange.com/questions/345915/trying-to-implement-the-jensen-shannon-divergence-for-multivariate-gaussians/419421#419421
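
Since the linked answer is about multivariate Gaussians: the same estimator should work unchanged with frozen scipy.stats.multivariate_normal objects, which also expose rvs and pdf. A minimal sketch (the means and covariances below are made up):

import numpy as np
import scipy.stats as st

np.random.seed(0)  # fix the Monte Carlo samples for reproducibility
p = st.multivariate_normal(mean=[0.0, 0.0], cov=np.eye(2))
q = st.multivariate_normal(mean=[3.0, 0.0], cov=np.eye(2))
print(distributions_js(p, q))  # Monte Carlo estimate in bits, between 0 and 1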

@davidlenz (Author)

Thank you!
