Skip to content

Instantly share code, notes, and snippets.

@mycarta
Forked from satra/distcorr.py
Created September 24, 2019 15:27
Show Gist options
  • Save mycarta/c97e8c470bf00952961c0e29d7e91b24 to your computer and use it in GitHub Desktop.
Save mycarta/c97e8c470bf00952961c0e29d7e91b24 to your computer and use it in GitHub Desktop.
Distance Correlation in Python
from scipy.spatial.distance import pdist, squareform
import numpy as np
from numbapro import jit, float32
def distcorr(X, Y):
    """Compute the sample distance correlation between X and Y.

    The statistic is built from the double-centered pairwise Euclidean
    distance matrices of the two samples. It lies in [0, 1].

    Parameters
    ----------
    X, Y : array_like
        Either 1-D (``n`` scalar observations) or 2-D with shape
        ``(n, n_features)``. Both must contain the same number of
        observations ``n`` along the first axis; the number of features
        may differ between X and Y.

    Returns
    -------
    numpy.float64
        The distance correlation. ``0.0`` is returned when either sample
        has zero distance variance (e.g. a constant input), where the
        ratio would otherwise be 0/0.

    Raises
    ------
    ValueError
        If X and Y do not have the same number of observations.

    Examples
    --------
    >>> a = [1,2,3,4,5]
    >>> b = np.array([1,2,9,4,4])
    >>> round(float(distcorr(a, b)), 6)
    0.762676
    """
    X = np.atleast_1d(X)
    Y = np.atleast_1d(Y)
    # A 1-D input is n observations of a single feature: make it a column.
    # Checking ndim (rather than comparing the element count with len) keeps
    # an (n, 1) column vector 2-D instead of turning it into (n, 1, 1),
    # which pdist would reject.
    if X.ndim == 1:
        X = X[:, None]
    if Y.ndim == 1:
        Y = Y[:, None]

    n = X.shape[0]
    if Y.shape[0] != n:
        raise ValueError('Number of samples must match')

    # Full n-by-n pairwise Euclidean distance matrices.
    a = squareform(pdist(X))
    b = squareform(pdist(Y))

    # Double centering: subtract column and row means, add back grand mean.
    A = a - a.mean(axis=0)[None, :] - a.mean(axis=1)[:, None] + a.mean()
    B = b - b.mean(axis=0)[None, :] - b.mean(axis=1)[:, None] + b.mean()

    n_sq = float(n * n)
    # Clamp at zero: rounding can push this slightly negative for
    # (near-)independent data, and sqrt of a negative would give nan.
    dcov2_xy = max((A * B).sum() / n_sq, 0.0)
    dcov2_xx = (A * A).sum() / n_sq
    dcov2_yy = (B * B).sum() / n_sq

    denom = np.sqrt(np.sqrt(dcov2_xx) * np.sqrt(dcov2_yy))
    if denom == 0:
        # Zero distance variance in X or Y: the correlation is defined as 0.
        return np.float64(0.0)
    return np.sqrt(dcov2_xy) / denom
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment