Skip to content

Instantly share code, notes, and snippets.

@suzusuzu
Last active November 21, 2019 23:23
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save suzusuzu/27894771e9fb1531f690d4fc4642b271 to your computer and use it in GitHub Desktop.
Save suzusuzu/27894771e9fb1531f690d4fc4642b271 to your computer and use it in GitHub Desktop.
An implementation of "Divergence Estimation for Multidimensional Densities Via k-Nearest-Neighbor Distance" (https://www.princeton.edu/~kulkarni/Papers/Journals/j068_2009_WangKulVer_TransIT.pdf)
import numpy as np
import matplotlib.pyplot as plt
from sklearn.neighbors import NearestNeighbors
def kl_d_norm(mu1, sigma1, mu2, sigma2):
    """Return the closed-form KL divergence between two univariate normals.

    Computes D(P || Q) for P = N(mu1, sigma1**2) and Q = N(mu2, sigma2**2):

        log(sigma2 / sigma1)
        + (sigma1**2 + (mu1 - mu2)**2) / (2 * sigma2**2)
        - 1/2

    Args:
        mu1: Mean of P.
        sigma1: Standard deviation of P.
        mu2: Mean of Q.
        sigma2: Standard deviation of Q.

    Returns:
        The divergence D(P || Q). It is 0 when both distributions coincide.
    """
    log_scale_ratio = np.log(sigma2 / sigma1)
    quadratic_term = (sigma1**2 + (mu1 - mu2)**2) / (2 * sigma2**2)
    return log_scale_ratio + quadratic_term - 1 / 2
def kl_d(x, y, k=1):
    """Estimate the KL divergence D(P || Q) from samples via k-NN distances.

    For every sample x_i, let rho_i be the distance to its k-th nearest
    neighbour among the other samples of ``x`` and nu_i the distance to its
    k-th nearest neighbour in ``y``. The estimate is

        dim / n * sum_i log(nu_i / rho_i) + log(m / (n - 1))

    Args:
        x: Samples drawn from P, shape (n, dim). A 1-D array is treated as
            n one-dimensional samples.
        y: Samples drawn from Q, shape (m, dim). A 1-D array is treated as
            m one-dimensional samples.
        k: Which nearest neighbour to use (1 = closest).

    Returns:
        The estimated divergence as a float. Duplicate points can produce
        zero distances, in which case the result may be infinite or NaN.

    Raises:
        ValueError: If ``k`` is smaller than 1, ``x`` has no more than ``k``
            samples, or ``y`` has fewer than ``k`` samples.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    # Treat flat arrays as column vectors of one-dimensional samples.
    if x.ndim == 1:
        x = x.reshape(-1, 1)
    if y.ndim == 1:
        y = y.reshape(-1, 1)

    n = x.shape[0]
    m = y.shape[0]
    dim = x.shape[1]

    if k < 1:
        raise ValueError("k must be at least 1, got %r" % (k,))
    if n <= k:
        raise ValueError("x needs more than k=%r samples, got %d" % (k, n))
    if m < k:
        raise ValueError("y needs at least k=%r samples, got %d" % (k, m))

    # k + 1 neighbours within x: every query point is also in the fitted set,
    # so its nearest hit is itself at distance 0 and column k is the k-th
    # neighbour among the remaining samples.
    nn_x = NearestNeighbors(n_neighbors=k + 1).fit(x)
    nn_y = NearestNeighbors(n_neighbors=k).fit(y)

    # One batched query per index instead of two queries per sample.
    rho = nn_x.kneighbors(x)[0][:, k]
    nu = nn_y.kneighbors(x)[0][:, k - 1]

    return dim / n * np.sum(np.log(nu / rho)) + np.log(m / (n - 1))
# Experiment: compare the k-NN estimate of D(N(mu, 1) || N(0, 1)) with the
# closed-form value while sweeping the mean offset mu over [-1, 1].
mus = np.linspace(-1, 1, 100)
true_ds = []
est_ds = []
for mu in mus:
    # x is drawn from N(mu, 1) and y from N(0, 1), and kl_d(x, y) estimates
    # D(x-distribution || y-distribution); pass the reference parameters in
    # the same direction. (With equal variances both directions give
    # mu**2 / 2, so the plotted curve is unchanged.)
    true_d = kl_d_norm(mu, 1.0, 0, 1.0)
    true_ds.append(true_d)
    x = np.random.normal(mu, size=1000).reshape(-1, 1)
    y = np.random.normal(size=1000).reshape(-1, 1)
    est_d = kl_d(x, y, k=10)
    est_ds.append(est_d)

# Plot the analytic curve against the sample-based estimate.
plt.plot(mus, true_ds, label='true')
plt.plot(mus, est_ds, label='estimation')
plt.xlabel('mu')
plt.ylabel('Kullback–Leibler divergence')
plt.legend()
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment