@javipus
Last active April 14, 2024 14:05
Kullback-Leibler Divergence Estimator.
__pycache__/*

Python implementation of the Kullback-Leibler divergence estimator described in the paper linked in the docstring below. It relies on scikit-learn for the k-nearest-neighbour search.
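For reference, klDivergence below computes the k-nearest-neighbour estimator

$$\widehat{D}(P\|Q) = \frac{d}{n}\sum_{i=1}^{n}\log\frac{\nu_k(x_i)}{\rho_k(x_i)} + \log\frac{m}{n-1},$$

where the $x_i$ are the $n$ samples from $P$, $m$ is the number of samples from $Q$, $d$ is the dimension, $\rho_k(x_i)$ is the distance from $x_i$ to its $k$-th nearest neighbour among the other $x$'s, and $\nu_k(x_i)$ is the distance from $x_i$ to its $k$-th nearest neighbour among the samples from $Q$ (notation paraphrased from the code, not necessarily the paper's).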

from __future__ import print_function
from __future__ import division
import os, sys, time, copy, re
sys.dont_write_bytecode = True
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.neighbors import NearestNeighbors as kNN
def klDivergence(x, y, logBase = 2, returnk = False, **kwds):
    # TODO the estimator converges a.s. to the actual KL divergence, but the convergence rate looks terrible
    # TODO simple tests with 1-D gaussians suggest (this implementation of) the estimator is biased
    # TODO does k-NN density estimation with k>1 improve results (cost-effectively, given it takes more computation time)?
    """
    Estimate the KL divergence D(P||Q) between unknown distributions P and Q using iid samples X_i~P and Y_i~Q.
    The estimator is the one developed in http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.422.5121&rep=rep1&type=pdf

    @param x: List of samples from the *true* distribution, P.
    @param y: List of samples from the *approximate* distribution, Q.
    @param logBase: Base of the logarithm used to compute the KL divergence (changes the units in which entropy is expressed).
    @param returnk: If True, also return the k used for the k-NN estimation as a second value.
    @param kwds: Keyword arguments passed to the sklearn.neighbors.NearestNeighbors constructor.
    """
    n, m = len(x), len(y)

    # Elementwise log in the requested base; log(0) is mapped to 0 to avoid -inf.
    _log = lambda x, b: np.log2(x) / np.log2(b) if x != 0 else 0
    _log = np.vectorize(_log)

    if not hasattr(x[0], '__len__'): # 1-D data
        x, y = map(lambda z: np.array(z).reshape(-1, 1), (x, y))

    d = x.shape[1]
    if d != y.shape[1]:
        raise Exception('Dimension mismatch: X{} != Y{}'.format(x.shape, y.shape))

    # nnDistX is zero if a point is sampled >= twice, and that's a problem because it goes in a denominator.
    # Remember what you're trying to do here: calculate the distance to the 1-NN so that the empirical estimate of
    # the pdf is ~1/dist_1NN. It would then make sense to find the smallest k for which dist_kNN > 0 and estimate
    # the pdf as ~k/dist_kNN - but there's no way of knowing that k upfront, so you have to loop through the
    # particles and check which ones have been sampled more than once. That's what the while loop below does.
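    # Example: with x = [0., 0., 1.] the two identical points have 1-NN distance 0, so k gets bumped to 2.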
    k = 1 # start with 1-NN
    while True:
        # n_neighbors = k+1 because each point is its own 0th neighbour at distance 0.
        knnX = kNN(n_neighbors = k+1, **kwds).fit(x)
        nnDistX = knnX.kneighbors(x)[0][:, k]
        if not nnDistX.all():
            k += 1
        else:
            break

    # Distance from each x_i to its k-th nearest neighbour among the y's (no self-match to exclude here).
    knnY = kNN(n_neighbors = k, **kwds).fit(y)
    nnDistY = knnY.kneighbors(x)[0][:, k-1]

    kl = (d/n) * sum(_log(nnDistY/nnDistX, logBase)) + _log((m/(n-1)), logBase)

    if returnk:
        return kl, k
    else:
        return kl
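# A minimal usage sketch, with arbitrary toy parameters (1-D Gaussians with variances 5 and 10):
#
#   x = np.sqrt(5.) * np.random.randn(1000)                    # X_i ~ P = N(0, 5)
#   y = np.sqrt(10.) * np.random.randn(1000)                   # Y_i ~ Q = N(0, 10)
#   d_hat, k = klDivergence(x, y, logBase = 2, returnk = True) # estimate of D(P||Q) in bits, plus the k used
#
# The analytical value for comparison is given by kl_th below.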
def kl_th(p, q, logBase = 2):
    """
    Analytical formula for the Kullback-Leibler divergence in simple cases.

    @param p, q: Dictionaries containing the model specification, with keys:
        - 'model': Only 'normal' for now - TODO: include exponentials, etc.
        - Other keys: model parameters.
        TODO: This argument should be something standard. Maybe pass an instance of scipy.stats.<distribution_name>?
    @param logBase: Entropy measured in:
        - Bits if log2.
        - Nats if ln.
        - Hartleys (a.k.a. bans) if log10.
    """
    p_type, q_type = p['model'], q['model']

    tr = np.trace
    inv = np.linalg.inv
    det = np.linalg.det

    if p_type == q_type == 'normal':
        mu_p, mu_q = p['mu'], q['mu']
        sigma_p, sigma_q = p['sigma'], q['sigma']
        dmu = mu_q - mu_p
        d = len(dmu)
        # Gaussian closed form, computed in nats first and converted to the requested base at the end.
        if d > 1:
            kl_nats = .5 * (tr(np.dot(inv(sigma_q), sigma_p)) + np.dot(dmu, np.dot(inv(sigma_q), dmu)) - d + np.log(det(sigma_q)/det(sigma_p)))
        else:
            # 1-D case: here 'sigma' is the variance.
            kl_nats = .5 * (sigma_p/sigma_q + (dmu**2)/sigma_q - d + np.log(sigma_q/sigma_p))
        # Changing the log base rescales the whole divergence, not just the log-determinant term.
        kl = kl_nats / np.log(logBase)
    elif p_type == 'normal' and q_type == 'exponential':
        raise Exception('Working on it!')
    elif p_type == 'exponential' and q_type == 'normal':
        raise Exception('Working on it!')
    else:
        raise Exception('Distributions {} and/or {} unknown!'.format(p_type, q_type))

    return kl
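The TODO in kl_th's docstring suggests replacing the dictionary argument with something standard such as frozen scipy.stats distributions. A possible sketch of that interface for the 1-D normal case (the helper name kl_th_scipy is hypothetical, and only the Gaussian closed form is covered):

import numpy as np
from scipy import stats

def kl_th_scipy(p, q, logBase = 2):
    # p and q are frozen 1-D normals, e.g. p = stats.norm(loc = 0, scale = np.sqrt(5)).
    mu_p, var_p = p.mean(), p.var()
    mu_q, var_q = q.mean(), q.var()
    # Gaussian closed form in nats, then converted to the requested base.
    kl_nats = .5 * (var_p / var_q + (mu_q - mu_p)**2 / var_q - 1 + np.log(var_q / var_p))
    return kl_nats / np.log(logBase)
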
from warnings import warn
import os, re, time, sys, copy
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from kl_div_estimator import klDivergence, kl_th
kl = klDivergence
def runToy(d = 1):
    p = {
        'model': 'normal',
        'mu': np.zeros(d),
        'sigma': 5*np.eye(d)
        }
    # NOTE: q0 and q1 are specified as 1-D models; the d > 1 branch below would need matching d-dimensional parameters.
    q0 = {
        'model': 'normal',
        'mu': np.array([0]),
        'sigma': np.array([5])
        }
    q1 = {
        'model': 'normal',
        'mu': np.array([0]),
        'sigma': np.array([10])
        }

    n = 1000 # samples per draw
    m = 1000 # number of repetitions

    kl_emp = []
    for k in range(m):
        if d == 1:
            # 'sigma' entries are variances, hence the square root.
            x = p['sigma'][0] ** (.5) * np.random.randn(n) + p['mu']
            y = q0['sigma'][0] ** (.5) * np.random.randn(n) + q0['mu']
            z = q1['sigma'][0] ** (.5) * np.random.randn(n) + q1['mu']
        else:
            x = np.random.multivariate_normal(p['mu'], p['sigma'], n)
            y = np.random.multivariate_normal(q0['mu'], q0['sigma'], n)
            z = np.random.multivariate_normal(q1['mu'], q1['sigma'], n)
        kl_emp.append(list(map(lambda yy: kl(x, yy), (y, z))))
    kl_emp = np.array(kl_emp)

    kl_real = list(map(lambda q: kl_th(p, q), (q0, q1)))

    plt.hist(kl_emp[:, 0], color = 'b', alpha = .25)
    plt.hist(kl_emp[:, 1], color = 'r', alpha = .25)
    plt.axvline(np.mean(kl_emp[:, 0]), ls = '--', color = 'b', lw = 3, label = None)
    plt.axvline(np.mean(kl_emp[:, 1]), ls = '--', color = 'r', lw = 3, label = None)
    plt.axvline(kl_real[0], c = 'b', lw = 3, label = r'D(P||$Q_0$)')
    plt.axvline(kl_real[1], c = 'r', lw = 3, label = r'D(P||$Q_1$)')
    plt.xlabel('KL Divergence')
    plt.ylabel('Count')
    plt.legend()
    plt.tight_layout()
    plt.show()

    return kl_real, kl_emp

if __name__ == '__main__':
    klth, kle = runToy()
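For the toy setup above (reading each 'sigma' entry as a variance, which is how runToy samples from it), the closed-form values the estimates should scatter around are

$$D(P\|Q_0) = 0, \qquad D(P\|Q_1) = \tfrac{1}{2}\left(\tfrac{5}{10} + \tfrac{0}{10} - 1 + \ln\tfrac{10}{5}\right) \approx 0.097\ \text{nats} \approx 0.14\ \text{bits},$$

since P and Q_0 are both N(0, 5) while Q_1 is N(0, 10).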