Created August 30, 2017 12:39
Jarino/cb6d9b39abcf773a1fb0e9a90ee67db9
Cauchy-Schwarz divergence
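For reference, here is a sketch of the standard definition of the Cauchy-Schwarz divergence between two densities p and q, which the function below estimates. By the Cauchy-Schwarz inequality, the ratio inside the logarithm lies in (0, 1]. The divergence is therefore nonnegative. It is 0 exactly when p and q are proportional, which for normalized densities means p = q. The expression is also visibly symmetric in p and q. On a uniform grid with spacing Δ, each integral becomes a Riemann sum and the factors of Δ cancel. This is why plain sums of the evaluated densities are sufficient.

```latex
D_{CS}(p, q) = -\log \frac{\int p(x)\,q(x)\,dx}{\sqrt{\int p(x)^2\,dx \int q(x)^2\,dx}}

% Uniform grid x_1, ..., x_n with spacing \Delta:
D_{CS}(p, q) \approx -\log \frac{\Delta \sum_i p(x_i)\,q(x_i)}{\sqrt{\Delta \sum_i p(x_i)^2 \cdot \Delta \sum_i q(x_i)^2}}
             = -\log \frac{\sum_i p(x_i)\,q(x_i)}{\sqrt{\sum_i p(x_i)^2 \sum_i q(x_i)^2}}
```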
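```python
from math import log, sqrt

import numpy as np
from scipy.stats import gaussian_kde


def cs_divergence(p1, p2, num_points=1000):
    """
    Calculates the Cauchy-Schwarz divergence between two probability distributions,
    each given as a 1-D array of samples. CS divergence is symmetric, hence the order
    of the arguments does not matter. The result lies in the interval [0, infinity),
    where 0 is obtained when the two probability distributions are the same.

    Args:
        p1 (numpy array): 1-D samples from the first distribution
        p2 (numpy array): 1-D samples from the second distribution
        num_points (int): number of grid points at which both densities are evaluated

    Returns:
        float: CS divergence
    """
    p1, p2 = np.asarray(p1, dtype=float), np.asarray(p2, dtype=float)
    # Estimate each density with a Gaussian kernel density estimate (KDE).
    p1_kernel = gaussian_kde(p1)
    p2_kernel = gaussian_kde(p2)
    # Evaluate both densities on one shared, evenly spaced grid that covers both
    # samples plus a margin of 3 kernel bandwidths, so the tails are included.
    # (The original evaluated at the integers 0..len(p1)-1, unrelated to the data.)
    margin = 3 * sqrt(max(p1_kernel.covariance[0, 0], p2_kernel.covariance[0, 0]))
    lo = min(p1.min(), p2.min()) - margin
    hi = max(p1.max(), p2.max()) + margin
    r = np.linspace(lo, hi, num_points)
    p1_computed = p1_kernel(r)
    p2_computed = p2_kernel(r)
    # Riemann sums approximate the integrals; the grid spacing cancels in the ratio.
    numerator = np.sum(p1_computed * p2_computed)
    denominator = sqrt(np.sum(p1_computed ** 2) * np.sum(p2_computed ** 2))
    return -log(numerator / denominator)
```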
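The ratio computed at the end of `cs_divergence` can be checked independently of KDE. The sketch below applies the same discrete formula to two densities known in closed form. `gaussian_pdf` and `cs_divergence_on_grid` are hypothetical helpers written only for this illustration; they are not part of the gist. For two Gaussians with equal variance σ², the integrals have closed forms and the divergence reduces to (μ1 - μ2)² / (4σ²). Comparing N(0, 1) with N(2, 1) should therefore give 1.

```python
import numpy as np


def gaussian_pdf(x, mu, sigma):
    """Hypothetical helper: density of N(mu, sigma^2) evaluated at x."""
    return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))


def cs_divergence_on_grid(p, q):
    """Hypothetical helper: CS divergence from two densities sampled on one uniform grid.

    The grid spacing cancels between numerator and denominator, so plain sums suffice.
    """
    numerator = np.sum(p * q)
    denominator = np.sqrt(np.sum(p ** 2) * np.sum(q ** 2))
    return float(-np.log(numerator / denominator))


# A wide, fine grid so that the Gaussian tails are fully captured.
x = np.linspace(-20.0, 20.0, 4001)
p = gaussian_pdf(x, 0.0, 1.0)
q = gaussian_pdf(x, 2.0, 1.0)

d_pq = cs_divergence_on_grid(p, q)  # closed form: (0 - 2)^2 / 4 = 1
d_qp = cs_divergence_on_grid(q, p)  # symmetry: same value as d_pq
d_pp = cs_divergence_on_grid(p, p)  # identical densities: approximately 0
print(d_pq, d_qp, d_pp)
```

Replacing the exact densities with `gaussian_kde` estimates, as `cs_divergence` does, adds estimation error. On finite samples its output will only approximate these values.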