@lars-von-buchholtz · Last active June 25, 2021
Kullback-Leibler divergence for multivariate samples from continuous distributions
library(RANN)
#' Estimate the Kullback-Leibler divergence D(P||Q) between two multivariate samples.
#'
#' Adapted for R from the Python code at
#' https://gist.github.com/atabakd/ed0f7581f8510c8587bc2f41a094b518,
#' which implements the estimator described in Fernando Pérez-Cruz,
#' "Kullback-Leibler Divergence Estimation of Continuous Distributions",
#' Proceedings of the IEEE International Symposium on Information Theory,
#' 2008, 1666-1670.
#'
#' @param X numeric matrix (n x d); samples from distribution P, which
#'   typically represents the true distribution.
#' @param Y numeric matrix (m x d); samples from distribution Q, which
#'   typically represents the approximate distribution.
#' @return The estimated Kullback-Leibler divergence D(P||Q) as a single number.
kl <- function(X, Y) {
  # get the relevant dimensions
  d <- ncol(X)  # number of dimensions; must be the same for X and Y
  n <- nrow(X)  # number of samples in X
  m <- nrow(Y)  # number of samples in Y

  # get nearest-neighbor distances from the kd-tree via nn2() in the RANN package:
  # for each point in X, the distance to its nearest neighbor within X (ask for
  # the 2 closest and take the second, since the closest is the point itself) ...
  r <- nn2(X, X, k = 2, eps = .01)$nn.dists[, 2]  # length-n vector
  # ... and the distance to its nearest neighbor in Y
  s <- nn2(Y, X, k = 1, eps = .01)$nn.dists       # n x 1 matrix
  # Note: Eq. 14 of the paper has a sign error; the first term on its
  # right-hand side should be negative, as implemented below.
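  # Written out, the estimator computed below is
  #   D_hat(P||Q) = (d/n) * sum_i log(s_i / r_i) + log(m / (n - 1))
  # where r_i is the distance from x_i to its nearest neighbor in X (excluding
  # x_i itself) and s_i is the distance from x_i to its nearest neighbor in Y.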
  return(-sum(log(r / s)) * d / n + log(m / (n - 1)))
}
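
As a quick sanity check (a minimal sketch, not part of the gist itself): for two Gaussians with identity covariance, the analytic divergence is half the squared distance between the means, so sampling from N((0,0), I) and N((1,0), I) should yield an estimate near 0.5.

set.seed(42)
X <- cbind(rnorm(5000), rnorm(5000))            # 5000 samples from N((0, 0), I), playing P
Y <- cbind(rnorm(5000, mean = 1), rnorm(5000))  # 5000 samples from N((1, 0), I), playing Q
kl(X, Y)  # should return a value close to the analytic D(P||Q) = 0.5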