Skip to content

Instantly share code, notes, and snippets.

@T-B-F
Last active October 25, 2018 08:04
Show Gist options
  • Save T-B-F/11b9527687d07eb903614fb59455073c to your computer and use it in GitHub Desktop.
Save T-B-F/11b9527687d07eb903614fb59455073c to your computer and use it in GitHub Desktop.
Simple parallel Jensen-Shannon divergence (JSD) computed with scikit-learn's pairwise_distances, which parallelises through joblib
import numpy as np
import sklearn
from sklearn.metrics.pairwise import check_pairwise_arrays
from sklearn.metrics.pairwise import pairwise_distances
# Report the library versions in use (scikit-learn first, then NumPy),
# one per line.
print(sklearn.__version__, np.__version__, sep="\n")
def posdef_check_value(d):
    """Zero out every non-finite entry of ``d`` in place.

    Despite the name, nothing about positive-definiteness is checked: the
    function only overwrites NaN, +inf and -inf entries with 0.

    Parameters
    ----------
    d : numpy.ndarray
        Array to clean. It is modified in place, so it must support boolean
        mask assignment (a plain Python float or NumPy scalar will not work).

    Returns
    -------
    None
    """
    # ``~isfinite`` selects exactly the NaN and +/-inf entries in one pass.
    d[~np.isfinite(d)] = 0
## Distance functions
def KL(a, b):
    """Compute the Kullback-Leibler divergence ``sum(a * log(a / b))``.

    Terms that evaluate to NaN or +/-inf are replaced by 0 before summing
    (see ``posdef_check_value``). This covers the ``0 * log(0)`` case, but it
    also zeroes terms where ``b`` is 0 while ``a`` is positive. Inputs are
    not normalised or validated as probability distributions.

    Parameters
    ----------
    a, b : array-like
        Either two 1-D vectors of equal length, or two 2-D arrays whose rows
        are the vectors to compare.

    Returns
    -------
    numpy scalar or numpy.ndarray
        For 1-D inputs, the divergence of ``a`` from ``b``.
        For 2-D inputs, an array of shape ``(len(b), len(a))`` whose entry
        ``[j, i]`` is the divergence of ``a[i]`` from ``b[j]`` (note the
        transposed layout).

    Raises
    ------
    ValueError
        If the inputs are not both 1-D or both 2-D.
    """
    a = np.asarray(a)
    b = np.asarray(b)

    if a.ndim == 1 and b.ndim == 1:
        # Division by zero / log(0) warnings are expected here: the
        # offending terms are zeroed right after.
        with np.errstate(divide="ignore", invalid="ignore"):
            d = a * np.log(a / b)
        posdef_check_value(d)
        return np.sum(d)

    if a.ndim == 2 and b.ndim == 2:
        X, Y = check_pairwise_arrays(a, b)
        # (n, 1, k) against (m, k) broadcasts to every row pair: (n, m, k).
        X = X[:, np.newaxis]
        with np.errstate(divide="ignore", invalid="ignore"):
            d = X * np.log(X / Y)
        posdef_check_value(d)
        return np.sum(d, axis=2).T

    raise ValueError(
        "Dimension error in KL, a={}, b={}".format(a.ndim, b.ndim)
    )
def JSD(a, b):
    """Compute the Jensen-Shannon divergence between ``a`` and ``b``.

    With ``h = 0.5 * (a + b)`` the result is
    ``0.5 * (sum(a * log(a / h)) + sum(b * log(b / h)))``. Terms that
    evaluate to NaN or +/-inf (e.g. ``0 * log(0)``) contribute 0. Inputs are
    not normalised or validated as probability distributions.

    Parameters
    ----------
    a, b : array-like
        1-D vectors, or 2-D arrays whose rows are the vectors to compare.
        The two may have different dimensionality (one 1-D, one 2-D).

    Returns
    -------
    numpy scalar or numpy.ndarray
        A scalar when both inputs are 1-D. Otherwise an array of shape
        ``(rows of b, rows of a)`` (a 1-D input counts as a single row)
        whose entry ``[j, i]`` compares ``a[i]`` with ``b[j]``.

    Raises
    ------
    ValueError
        If either input is not 1-D or 2-D.
    """
    a = np.asarray(a)
    b = np.asarray(b)

    if a.ndim == 1 and b.ndim == 1:
        h = 0.5 * (a + b)
        return 0.5 * (KL(a, h) + KL(b, h))

    if a.ndim not in (1, 2) or b.ndim not in (1, 2):
        raise ValueError(
            "Dimension error in JSD, a={}, b={}".format(a.ndim, b.ndim)
        )

    # Promote 1-D inputs to a single row so that one broadcasted expression
    # covers every remaining combination of input ranks.
    rows_a = np.atleast_2d(a)[np.newaxis, :, :]  # (1, n, k)
    rows_b = np.atleast_2d(b)[:, np.newaxis, :]  # (m, 1, k)
    h = 0.5 * (rows_a + rows_b)  # (m, n, k)

    # Division by zero / log(0) warnings are expected here: the offending
    # terms are zeroed right after.
    with np.errstate(divide="ignore", invalid="ignore"):
        d1 = rows_a * np.log(rows_a / h)
        d2 = rows_b * np.log(rows_b / h)
    posdef_check_value(d1)
    posdef_check_value(d2)

    return 0.5 * (np.sum(d1, axis=2) + np.sum(d2, axis=2))
# Two random 1-D sample vectors of length 10 (not used by the demo below).
a, b = np.random.random(10), np.random.random(10)

# Demo: 100 random rows of 10 values each, compared pairwise with JSD using
# two parallel jobs. The rows are raw uniform samples and are not normalised.
frequencies = np.random.random(size=(100, 10))
distances = pairwise_distances(X=frequencies, metric=JSD, n_jobs=2)
print(distances.shape)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment