Skip to content

Instantly share code, notes, and snippets.

@SherazKhan
Forked from kailashbuki/dcor.py
Created June 20, 2017 23:42
Show Gist options
  • Save SherazKhan/4b2fe45c50a402dd73990c98450b2c89 to your computer and use it in GitHub Desktop.
Save SherazKhan/4b2fe45c50a402dd73990c98450b2c89 to your computer and use it in GitHub Desktop.
Computes the distance correlation between two matrices in Python.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Computes the distance correlation between two matrices.
https://en.wikipedia.org/wiki/Distance_correlation
"""
import numpy as np
from scipy.spatial.distance import pdist, squareform
__author__ = "Kailash Budhathoki"
__email__ = "kbudhath@mpi-inf.mpg.de"
__copyright__ = "Copyright (c) 2016"
__license__ = "MIT"
def dcov(X, Y):
    """Return the distance covariance of two centered distance matrices.

    X and Y are expected to have the same shape. The result is the square
    root of the sum of their element-wise product, divided by the number
    of rows of X.
    """
    sample_count = X.shape[0]
    # np.multiply is element-wise for both ndarray and np.matrix operands.
    product_total = np.sum(np.multiply(X, Y))
    return np.sqrt(product_total) / sample_count
def dvar(X):
    """Return the distance variance of a centered distance matrix X.

    The result is the square root of the sum of the squared entries of X,
    each divided by the squared number of rows; it equals ``dcov(X, X)``.
    """
    # Work on a plain ndarray view: for np.matrix, ``X ** 2`` is a matrix
    # power (X * X as a matrix product), not the element-wise square that
    # the formula requires.
    values = np.asarray(X)
    return np.sqrt(np.sum(values ** 2 / values.shape[0] ** 2))
def cent_dist(X):
    """Return the double-centered Euclidean distance matrix of the rows of X.

    The pairwise distances between rows of X are computed, then each cell
    has its row mean and column mean subtracted and the grand mean added.
    """
    distances = squareform(pdist(X))  # square pairwise distance matrix
    row_means = distances.mean(axis=1)
    col_means = distances.mean(axis=0)
    grand_mean = row_means.mean()
    # Broadcasting spreads the row means down the columns and the column
    # means across the rows without building tiled copies.
    return (
        distances
        - row_means[:, np.newaxis]
        - col_means[np.newaxis, :]
        + grand_mean
    )
def dcor(X, Y):
    """Return the distance correlation between the samples X and Y.

    X and Y must have the same number of rows. The result is 0.0 unless
    both distance variances are strictly positive.

    >>> X = np.matrix('1;2;3;4;5')
    >>> Y = np.matrix('1;2;9;4;4')
    >>> dcor(X, Y)
    0.76267624241686649
    """
    assert X.shape[0] == Y.shape[0]

    centered_x = cent_dist(X)
    centered_y = cent_dist(Y)

    covariance = dcov(centered_x, centered_y)
    variance_x = dvar(centered_x)
    variance_y = dvar(centered_y)

    # Guard against dividing by zero when either sample has no spread.
    if not (variance_x > 0.0 and variance_y > 0.0):
        return 0.0
    return covariance / np.sqrt(variance_x * variance_y)
if __name__ == "__main__":
    # Demo: five one-dimensional observations per sample, one per row.
    X = np.array([[1], [2], [3], [4], [5]])
    Y = np.array([[1], [2], [9], [4], [4]])
    # print is called as a function so the script runs on Python 3;
    # the Python 2 statement form ``print dcor(X, Y)`` is a SyntaxError there.
    print(dcor(X, Y))
@Surfcat888
Copy link

Tried this code, got the following error:

File "", line 79
print dcor(X, Y)
^
SyntaxError: invalid syntax

Any ideas on a solution?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment