@zhiyzuo
Last active May 12, 2022 10:58
Jensen-Shannon Divergence in Python
import numpy as np
import scipy as sp
import scipy.stats

def jsd(p, q, base=np.e):
    '''
    Implementation of pairwise `jsd` (Jensen-Shannon divergence) based on
    https://en.wikipedia.org/wiki/Jensen%E2%80%93Shannon_divergence
    '''
    ## convert to np.array
    p, q = np.asarray(p), np.asarray(q)
    ## normalize p, q to probabilities
    p, q = p/p.sum(), q/q.sum()
    ## mixture distribution
    m = 1./2*(p + q)
    return sp.stats.entropy(p, m, base=base)/2. + sp.stats.entropy(q, m, base=base)/2.
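
For illustration, a quick usage sketch with made-up probability vectors:

    p = [0.10, 0.40, 0.50]
    q = [0.80, 0.15, 0.05]

    print(jsd(p, q))           # JSD is symmetric ...
    print(jsd(q, p))           # ... so this matches the line above
    print(jsd(p, q, base=2))   # in bits; bounded above by 1 (ln 2 in nats)
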
@ano302 commented Apr 23, 2018

Please be aware that this implementation assumes p and q are already normalized. Otherwise you will get wrong results and may not even notice.
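
For illustration, a minimal sketch of what can go wrong with raw counts whose totals differ (made-up numbers; scipy.stats.entropy normalizes each of its arguments on its own, but the mixture m is still skewed if p and q are not normalized first):

    import numpy as np
    from scipy.stats import entropy

    p_counts = np.array([1., 0.])   # hypothetical raw counts
    q_counts = np.array([0., 9.])

    # skipping normalization: the mixture is dominated by q's larger total
    m_bad = (p_counts + q_counts) / 2
    jsd_bad = (entropy(p_counts, m_bad) + entropy(q_counts, m_bad)) / 2
    print(jsd_bad)   # ~1.20, above the ln(2) ~ 0.693 upper bound -> clearly wrong

    # normalizing first gives the correct result for disjoint supports
    p, q = p_counts / p_counts.sum(), q_counts / q_counts.sum()
    m = (p + q) / 2
    print((entropy(p, m) + entropy(q, m)) / 2)   # ln(2) ~ 0.693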

@davidlenz

A simple fix could be to add

    p = p / p.sum()
    q = q / q.sum()

which gives, in total:


import numpy as np
from scipy.stats import entropy

def js(p, q):
    p = np.asarray(p)
    q = np.asarray(q)
    # normalize to probabilities
    p = p / p.sum()
    q = q / q.sum()
    m = (p + q) / 2
    return (entropy(p, m) + entropy(q, m)) / 2

@manjeetnagi

What would be an optimal way to calculate the JSD when there is a large number of probability distributions? Say there is one data set with 10K probability distributions and I want to calculate the JSD of each one against every other one; that ends up being roughly 10K*10K/2 computations. Is there a smart way to do this without so many for loops or distributed processing? Is there anything in numpy that can help?

@zhiyzuo (Author) commented Jul 9, 2018

Thanks guys. I've updated the code to do the normalization first.

@manjeetnagi, I'm not really sure about an "optimal" way. If I were you, I would simply use joblib to parallelize the process.
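
A rough sketch of that joblib suggestion, assuming joblib is installed, using a made-up (n, k) array `dists` with one distribution per row and the `jsd` function from this gist:

    import numpy as np
    from itertools import combinations
    from joblib import Parallel, delayed

    n, k = 1000, 50
    dists = np.random.dirichlet(np.ones(k), size=n)   # hypothetical data

    pairs = list(combinations(range(n), 2))            # n*(n-1)/2 pairs
    values = Parallel(n_jobs=-1)(
        delayed(jsd)(dists[i], dists[j]) for i, j in pairs
    )

    # fill the symmetric matrix of pairwise divergences
    pairwise = np.zeros((n, n))
    for (i, j), v in zip(pairs, values):
        pairwise[i, j] = pairwise[j, i] = v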

@ayoubbenaissa

See:

    from scipy.spatial import distance
    distance.jensenshannon(a, b)
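
Worth noting: scipy.spatial.distance.jensenshannon returns the Jensen-Shannon distance, i.e. the square root of the divergence, so squaring it should match the jsd function above. A quick sketch, assuming SciPy >= 1.2 (where jensenshannon was added):

    import numpy as np
    from scipy.spatial import distance

    p = [0.10, 0.40, 0.50]
    q = [0.80, 0.15, 0.05]

    d = distance.jensenshannon(p, q, base=np.e)   # distance, not divergence
    print(d ** 2)       # squared distance = divergence; should match the line below
    print(jsd(p, q))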

@Darthholi

Just for those who land here looking for the Jensen-Shannon divergence (the distance is its square root) between two continuous distributions, estimated via Monte Carlo integration:

import numpy as np
import scipy.stats as st

def distributions_js(distribution_p, distribution_q, n_samples=10 ** 5):
    # jensen shannon divergence (the Jensen-Shannon distance is the square root of the divergence)
    # all the logarithms are defined as log2 (because of information entropy)
    X = distribution_p.rvs(n_samples)
    p_X = distribution_p.pdf(X)
    q_X = distribution_q.pdf(X)
    log_mix_X = np.log2(p_X + q_X)

    Y = distribution_q.rvs(n_samples)
    p_Y = distribution_p.pdf(Y)
    q_Y = distribution_q.pdf(Y)
    log_mix_Y = np.log2(p_Y + q_Y)

    return (np.log2(p_X).mean() - (log_mix_X.mean() - np.log2(2))
            + np.log2(q_Y).mean() - (log_mix_Y.mean() - np.log2(2))) / 2

print("should be different:")
print(distributions_js(st.norm(loc=10000), st.norm(loc=0)))
print("should be same:")
print(distributions_js(st.norm(loc=0), st.norm(loc=0)))

https://stats.stackexchange.com/questions/345915/trying-to-implement-the-jensen-shannon-divergence-for-multivariate-gaussians/419421#419421
