Skip to content

Instantly share code, notes, and snippets.

@sitag
Forked from satra/distcorr.py
Created September 5, 2019 23:22
Show Gist options
  • Save sitag/1395e6264cf27e21539e0939c760af84 to your computer and use it in GitHub Desktop.
Save sitag/1395e6264cf27e21539e0939c760af84 to your computer and use it in GitHub Desktop.
Distance Correlation in Python
from scipy.spatial.distance import pdist, squareform
import numpy as np
from numbapro import jit, float32
def distcorr(X, Y):
    """Compute the sample distance correlation between X and Y.

    Parameters
    ----------
    X, Y : array_like
        Either a 1-D sequence of n scalar observations, or a 2-D array of
        shape (n, p) holding one observation per row. X and Y must have the
        same number of observations but may differ in the number of columns.

    Returns
    -------
    float
        The distance correlation. 0.0 is returned when either input has
        zero distance variance (e.g. all observations identical), which is
        the conventional value for that degenerate case.

    Raises
    ------
    ValueError
        If X and Y do not contain the same number of observations.

    Examples
    --------
    >>> a = [1, 2, 3, 4, 5]
    >>> b = np.array([1, 2, 9, 4, 4])
    >>> round(float(distcorr(a, b)), 12)
    0.762676242417
    """
    X = np.atleast_1d(X)
    Y = np.atleast_1d(Y)
    # Only genuinely 1-D input is a vector of scalar samples that needs to be
    # turned into a column. An (n, 1) array is already in that layout;
    # reshaping it again would produce a 3-D array that pdist rejects.
    if X.ndim == 1:
        X = X[:, None]
    if Y.ndim == 1:
        Y = Y[:, None]

    n = X.shape[0]
    if Y.shape[0] != n:
        raise ValueError('Number of samples must match')

    # Pairwise Euclidean distance matrices between observations.
    a = squareform(pdist(X))
    b = squareform(pdist(Y))

    # Double-centre each matrix: subtract row and column means, add back the
    # grand mean.
    A = a - a.mean(axis=0)[None, :] - a.mean(axis=1)[:, None] + a.mean()
    B = b - b.mean(axis=0)[None, :] - b.mean(axis=1)[:, None] + b.mean()

    # Squared sample distance covariance / variances. The covariance term is
    # clamped at zero so rounding error cannot put a tiny negative number
    # under the square root.
    dcov2_xy = max((A * B).sum() / float(n * n), 0.0)
    dcov2_xx = (A * A).sum() / float(n * n)
    dcov2_yy = (B * B).sum() / float(n * n)

    # A zero distance variance would make the ratio 0/0 (NaN); distance
    # correlation is defined as 0 in that case.
    if dcov2_xx <= 0 or dcov2_yy <= 0:
        return 0.0

    dcor = np.sqrt(dcov2_xy) / np.sqrt(np.sqrt(dcov2_xx) * np.sqrt(dcov2_yy))
    return dcor
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment