Last active
October 25, 2018 08:04
-
-
Save T-B-F/11b9527687d07eb903614fb59455073c to your computer and use it in GitHub Desktop.
Simple parallel Jensen-Shannon divergence (JSD) computation using joblib
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sklearn.metrics.pairwise import pairwise_distances | |
import numpy as np | |
import sklearn | |
# Report the scikit-learn and NumPy versions, one per line, to help
# reproduce results.
print(sklearn.__version__, np.__version__, sep="\n")
def posdef_check_value(d):
    """Replace every non-finite entry (NaN, +inf, -inf) of ``d`` with 0, in place.

    Used on arrays of ``x * log(x / y)`` terms.  NumPy yields NaN for
    ``x == 0`` (``0 * log(0)`` or ``0 / 0``) and +inf for ``y == 0 < x``.
    Zeroing the NaN terms applies the usual ``0 * log(0) = 0`` convention.
    Note that zeroing the +inf terms drops contributions that would
    otherwise make a KL divergence infinite.

    Parameters
    ----------
    d : numpy.ndarray
        Array to clean.  It is modified in place and nothing is returned.
    """
    d[~np.isfinite(d)] = 0


## Distance functions
def _xlogx_ratio(p, q):
    """Return the broadcast element-wise terms ``p * log(p / q)``.

    Zero entries make NumPy emit divide-by-zero / invalid-value
    RuntimeWarnings.  The resulting NaN/inf terms are expected, and the
    callers below zero them with ``posdef_check_value``.  The warnings are
    therefore silenced here rather than flooding the output (e.g. once per
    pair under ``pairwise_distances``).
    """
    with np.errstate(divide="ignore", invalid="ignore"):
        return p * np.log(p / q)


def KL(a, b):
    """Compute the Kullback-Leibler divergence ``sum(a * log(a / b))``.

    Non-finite per-element terms are zeroed with ``posdef_check_value``
    before summing.

    Parameters
    ----------
    a, b : array_like
        Either two 1-D vectors of equal length, or two 2-D arrays with the
        same number of columns (one distribution per row).

    Returns
    -------
    numpy scalar or numpy.ndarray
        For 1-D inputs, the scalar ``KL(a || b)``.  For 2-D inputs of shapes
        ``(n, f)`` and ``(m, f)``, an ``(m, n)`` array whose entry ``[j, i]``
        is ``KL(a[i] || b[j])``.  Note the ``b``-major orientation, which
        matches ``JSD``'s 2-D output.

    Raises
    ------
    ValueError
        If the inputs are not both 1-D or both 2-D, or if 2-D inputs have a
        different number of columns.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    if a.ndim == 1 and b.ndim == 1:
        d = _xlogx_ratio(a, b)
        posdef_check_value(d)
        return np.sum(d)
    if a.ndim == 2 and b.ndim == 2:
        if a.shape[1] != b.shape[1]:
            raise ValueError(
                "Incompatible dimension for a and b matrices: "
                "a.shape[1] == {} while b.shape[1] == {}".format(
                    a.shape[1], b.shape[1]))
        # (n, 1, f) against (m, f) broadcasts to (n, m, f): every row of a
        # is paired with every row of b.
        d = _xlogx_ratio(a[:, np.newaxis], b)
        posdef_check_value(d)
        return np.sum(d, axis=2).T
    raise ValueError(
        "Dimension error in KL, a={}, b={}".format(a.ndim, b.ndim))


def JSD(a, b):
    """Compute the Jensen-Shannon divergence between ``a`` and ``b``.

    ``JSD(a, b) = 0.5 * (KL(a || h) + KL(b || h))`` with ``h = 0.5 * (a + b)``.

    Parameters
    ----------
    a, b : array_like
        1-D vectors or 2-D arrays (one distribution per row) sharing the
        same length / number of columns.

    Returns
    -------
    numpy scalar or numpy.ndarray
        * both 1-D: the scalar divergence;
        * ``a`` 2-D ``(n, f)`` and ``b`` 1-D: an array of shape ``(1, n)``;
        * ``a`` 1-D and ``b`` 2-D ``(m, f)``: an array of shape ``(m, 1)``;
        * both 2-D, ``(n, f)`` and ``(m, f)``: an array of shape ``(m, n)``
          whose entry ``[j, i]`` compares ``a[i]`` with ``b[j]``.

    Raises
    ------
    ValueError
        If either input is not 1-D or 2-D.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    if a.ndim not in (1, 2) or b.ndim not in (1, 2):
        raise ValueError(
            "Dimension error in JSD, a={}, b={}".format(a.ndim, b.ndim))
    if a.ndim == 1 and b.ndim == 1:
        h = 0.5 * (a + b)
        return 0.5 * (KL(a, h) + KL(b, h))
    # Add broadcast axes so that every row of a meets every row of b.  A
    # 1-D b is broadcast as-is, which leaves a leading axis of length 1.
    p = a[np.newaxis, :]
    q = b if b.ndim == 1 else b[:, np.newaxis]
    h = 0.5 * (p + q)
    d1 = _xlogx_ratio(p, h)
    posdef_check_value(d1)
    d2 = _xlogx_ratio(q, h)
    posdef_check_value(d2)
    return 0.5 * (np.sum(d1, axis=2) + np.sum(d2, axis=2))
if __name__ == "__main__":
    # Demo: pairwise JSD between 100 random 10-dimensional distributions,
    # computed with two joblib workers.  The guard keeps worker processes
    # from re-running the demo if they re-import this module.
    frequencies = np.random.random((100, 10))
    # JSD is defined on probability distributions, so normalise each row
    # to sum to 1.
    frequencies /= frequencies.sum(axis=1, keepdims=True)
    # With a callable metric, pairwise_distances calls JSD on pairs of 1-D
    # rows, so only JSD's 1-D branch is exercised here.
    distances = pairwise_distances(frequencies, metric=JSD, n_jobs=2)
    print(distances.shape)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment