@koshian2
Forked from atabakd/kl.py
Created June 19, 2023 19:54
KL divergence for multivariate samples
# https://mail.python.org/pipermail/scipy-user/2011-May/029521.html

import numpy as np


def KLdivergence(x, y):
    """Compute the Kullback-Leibler divergence between two multivariate samples.

    Parameters
    ----------
    x : 2D array (n,d)
        Samples from distribution P, which typically represents the true
        distribution.
    y : 2D array (m,d)
        Samples from distribution Q, which typically represents the approximate
        distribution.

    Returns
    -------
    out : float
        The estimated Kullback-Leibler divergence D(P||Q).

    References
    ----------
    Pérez-Cruz, F. Kullback-Leibler divergence estimation of continuous
    distributions. IEEE International Symposium on Information Theory, 2008.
    """
    from scipy.spatial import cKDTree as KDTree

    # Check that the dimensions are consistent.
    x = np.atleast_2d(x)
    y = np.atleast_2d(y)

    n, d = x.shape
    m, dy = y.shape

    assert d == dy

    # Build a KD-tree representation of each set of samples and find the
    # nearest neighbour of each point in x.
    xtree = KDTree(x)
    ytree = KDTree(y)

    # For x, get the first two nearest neighbours, since the closest one is the
    # sample itself.
    r = xtree.query(x, k=2, eps=.01, p=2)[0][:, 1]
    s = ytree.query(x, k=1, eps=.01, p=2)[0]

    # There is a mistake in the paper: in Eq. 14, the first term on the
    # right-hand side is missing a negative sign.
    return -np.log(r / s).sum() * d / n + np.log(m / (n - 1.))
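
As a rough sanity check, the estimate can be compared against the closed-form KL divergence between two Gaussians. The snippet below is a minimal usage sketch, not part of the original gist; the parameters, sample sizes, and the use of numpy.random.default_rng are illustrative assumptions.

# Usage sketch (illustrative, not from the original gist): compare the k-NN
# estimate with the analytical KL divergence between two 1-D Gaussians,
#   KL(N(mu1, s1^2) || N(mu2, s2^2))
#     = log(s2/s1) + (s1^2 + (mu1 - mu2)^2) / (2 * s2^2) - 1/2.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    mu1, s1 = 0.0, 1.0   # parameters of P (assumed values)
    mu2, s2 = 1.0, 1.5   # parameters of Q (assumed values)

    # Draw i.i.d. samples; one sample per row, as KLdivergence expects.
    x = rng.normal(mu1, s1, size=(10000, 1))
    y = rng.normal(mu2, s2, size=(10000, 1))

    estimated = KLdivergence(x, y)
    analytical = np.log(s2 / s1) + (s1 ** 2 + (mu1 - mu2) ** 2) / (2 * s2 ** 2) - 0.5

    print("estimated  D(P||Q):", estimated)
    print("analytical D(P||Q):", analytical)

With a few thousand samples the two numbers should agree to roughly the first decimal place; the estimator is consistent but noisy for small sample sizes.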