@basaks
Forked from atabakd/kl.py
KL divergence for multivariate samples
# https://mail.python.org/pipermail/scipy-user/2011-May/029521.html
import numpy as np
from scipy.spatial import cKDTree as KDTree
def KLdivergence(x, y):
"""Compute the Kullback-Leibler divergence between two multivariate samples.
Parameters
----------
x : 2D array (n,d)
Samples from distribution P, which typically represents the true
distribution.
y : 2D array (m,d)
Samples from distribution Q, which typically represents the approximate
distribution.
Returns
-------
out : float
The estimated Kullback-Leibler divergence D(P||Q).
References
----------
Pérez-Cruz, F. Kullback-Leibler divergence estimation of
continuous distributions IEEE International Symposium on Information
Theory, 2008.
"""
    # Check that the dimensions are consistent
    x = np.atleast_2d(x)
    y = np.atleast_2d(y)
    n, d = x.shape
    m, dy = y.shape
    assert d == dy
    # Build a KD tree representation of each sample set.
    xtree = KDTree(x)
    ytree = KDTree(y)

    # For each point in x, get its two nearest neighbours in x (the closest one
    # is the point itself) and its single nearest neighbour in y.
    r = xtree.query(x, k=2, eps=.01, p=2)[0][:, 1]
    s = ytree.query(x, k=1, eps=.01, p=2)[0]
    # Drop points whose neighbour distances are non-finite or (numerically)
    # zero, e.g. duplicate samples, since they would make the log term undefined.
    is_finite = np.isfinite(r) & np.isfinite(s) & (r > 1e-10) & (s > 1e-10)
    r = r[is_finite]
    s = s[is_finite]
    # 1-nearest-neighbour estimator from the cited Pérez-Cruz (2008) paper:
    # D(P||Q) ~ (d/n) * sum_i log(s_i / r_i) + log(m / (n - 1)), where r_i is the
    # distance from x_i to its nearest neighbour in x (excluding itself) and s_i
    # is its nearest-neighbour distance in y. The average is taken over the
    # points that survived the filtering above.
    return np.log(s / r).sum() * d / len(r) + np.log(m / (n - 1.))
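A minimal usage sketch, not part of the original gist (the Gaussian test distributions, sample sizes, and seed are illustrative assumptions): draw samples from two isotropic normals whose divergence is known in closed form and compare it with the estimate.

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # P = N(0, I) and Q = N(1, I) in d = 3 dimensions; the closed-form
    # divergence is 0.5 * ||mu_P - mu_Q||^2 = 1.5 nats.
    x = rng.normal(loc=0.0, scale=1.0, size=(5000, 3))
    y = rng.normal(loc=1.0, scale=1.0, size=(5000, 3))
    print(KLdivergence(x, y))  # should land near 1.5 for samples this large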