Skip to content

Instantly share code, notes, and snippets.

@dustinstansbury
Last active January 5, 2022 00:22
Show Gist options
  • Save dustinstansbury/a880568dca5fcbec087bc120a6779e13 to your computer and use it in GitHub Desktop.
from scipy import stats
from typing import Tuple, Union
import numpy as np
import pandas as pd
def xi_corr(
    x: np.ndarray, y: np.ndarray, return_p_values: bool = False
) -> Union[float, Tuple[float, float]]:
    """Implementation of Chatterjee's correlation, which can capture both monotonic
    and non-monotonic coupling between variables. The metric incorporates degree
    of uncertainty, so as the length of x, y increases, so too does the confidence
    in the correlation.

    Parameters
    ----------
    x, y : np.ndarray or array-like
        Vectors of equal length. Any array-like (list, tuple, ndarray) is
        accepted; pairs containing NaN are dropped before estimation.
    return_p_values : bool
        Whether to return the p-value for the correlation estimate. Note that
        we use the asymptotic limit, assuming no ties amongst values.

    Returns
    -------
    xi : float
        The Chatterjee correlation.
    p_value : float (optional)
        If return_p_values=True, returns the asymptotic p-value for the
        correlation. Just like any p-value, as the length of x, y increases,
        the p-value decreases if there is any relationship between the variables.

    Raises
    ------
    ValueError
        If fewer than two non-NaN (x, y) pairs remain after preprocessing.

    References
    ----------
    - "A new coefficient of correlation": https://arxiv.org/pdf/1909.10140.pdf
    - R implementation: https://souravchatterjee.su.domains//xi.R
    """

    def _preprocess(x, y):
        # Pair the inputs column-wise in a DataFrame to take advantage of
        # pandas' rank methods, and drop rows where either value is missing.
        # np.asarray generalizes the API to any array-like input.
        df = pd.DataFrame(
            np.vstack([np.asarray(x, dtype=float).ravel(),
                       np.asarray(y, dtype=float).ravel()]).T,
            columns=["x", "y"],
        )
        return df.dropna(axis=0, how="any")

    xy = _preprocess(x, y)
    n = len(xy)
    if n < 2:
        # The estimator needs at least one consecutive pair of ranks;
        # the original silently emitted NaN with runtime warnings here.
        raise ValueError("xi_corr requires at least two non-NaN (x, y) pairs")

    # Rank x with ties broken randomly: shuffle the rows, take first-come
    # ranks, then restore the original row order. This is equivalent to
    # rank(method='random'), which isn't readily available in
    # scipy.stats.rankdata or pandas' rank.
    pi = xy.x.sample(frac=1).rank(method="first").reindex_like(xy)

    # f[i] = fraction of observations with y_j <= y_i (empirical CDF)
    f = xy.y.rank(method="max").values / n
    # Reorder f by increasing (tie-broken) rank of x
    f = f[np.argsort(pi)]
    # g[i] = fraction of observations with y_j >= y_i
    g = (-xy.y).rank(method="max").values / n

    # xi = 1 - A1 / C, per equation (1) of the paper
    A1 = np.mean(np.abs(f[:-1] - f[1:])) * (n - 1) / (2 * n)
    C = np.mean(g * (1 - g))
    xi = 1 - (A1 / C)
    if return_p_values:
        # Asymptotic null distribution (no ties): sqrt(n) * xi ~ N(0, 2/5)
        pval = 1 - stats.norm.cdf(np.sqrt(n) * xi, scale=np.sqrt(2.0 / 5))
        return xi, pval
    return xi
def xi_corr2(x, y, return_p_values: bool = False):
    """Vectorized implementation of Chatterjee's xi that's supposed to be
    >300x faster than the rank-based pandas implementation.

    Pairs containing NaN are dropped before estimation, matching the
    preprocessing performed by `xi_corr`.

    Parameters
    ----------
    x, y : array-like
        Vectors of equal length.
    return_p_values : bool
        Whether to also return the asymptotic (no-ties) p-value.

    Returns
    -------
    xi : float
        The Chatterjee correlation.
    p_value : float (optional)
        Returned only if return_p_values=True.

    Reference: https://github.com/czbiohub/xicor/issues/17
    """
    x = np.asarray(x, dtype=float).ravel()
    y = np.asarray(y, dtype=float).ravel()
    # Drop pairs with missing values so results agree with xi_corr, which
    # drops NaNs during preprocessing (original returned nan here).
    valid = ~(np.isnan(x) | np.isnan(y))
    x, y = x[valid], y[valid]
    n = x.size
    # Reorder y by increasing x
    y = y[np.argsort(x, kind="quicksort")]
    # r[i] = number of j with y[j] <= y[i] ("max" rank via cumulative counts)
    _, inverse, counts = np.unique(y, return_inverse=True, return_counts=True)
    r = np.cumsum(counts)[inverse]
    # left[i] = number of j with y[j] >= y[i] (rank of -y)
    _, inverse, counts = np.unique(-y, return_inverse=True, return_counts=True)
    left = np.cumsum(counts)[inverse]
    # Equation (1) of the paper, in integer-rank form
    xi = 1 - n * np.abs(np.diff(r)).sum() / (2 * (left * (n - left)).sum())
    if return_p_values:
        # Asymptotic null distribution (no ties): sqrt(n) * xi ~ N(0, 2/5)
        pval = 1 - stats.norm.cdf(np.sqrt(n) * xi, scale=np.sqrt(2.0 / 5))
        return xi, pval
    return xi
@dustinstansbury
Copy link
Author

dustinstansbury commented Jan 4, 2022

Let's try out both algorithms

import numpy as np
from chaterjee_corr import xi_corr, xi_corr2

# Compare both implementations on random data of increasing length N;
# xi for independent (x, y) should shrink toward 0 as N grows, while the
# coupled cases (x vs. x, x^2, x^3, sin(x), sqrt(|x|)) should approach 1.
for n_obs in [10, 100, 1000, 10000]:
    s = f"N={n_obs}"
    print('-'*50)
    print(f"{s:^50}")
    print('-'*50)
    # Independent standard-normal vectors (null case)
    x = np.random.randn(n_obs)
    y = np.random.randn(n_obs)
    for algo in [xi_corr, xi_corr2]:

        # NOTE(review): the original comment said "For fun add some nans",
        # but no NaNs are actually injected below — likely leftover from an
        # earlier draft.
        print(f"{algo.__name__}(x, y)", algo(x, y, True))
        print(f"{algo.__name__}(x, x):", algo(x, x, True))
        print(f"{algo.__name__}(x, x^2):", algo(x, x ** 2, True))
        print(f"{algo.__name__}(x, x^3):", algo(x, x ** 3, True))
        print(f"{algo.__name__}(x, sin(x)):", algo(x, np.sin(x)))
        print(f"{algo.__name__}(x, sqrt(|x|)):", algo(x, np.abs(x) ** 0.5, True))

"""
--------------------------------------------------
                       N=10
--------------------------------------------------
xi_corr(x, y) (-0.15151515151515182, 0.7756475012431828)
xi_corr(x, x): (0.7272727272727273, 0.00013825695781910508)
xi_corr(x, x^2): (0.5454545454545455, 0.003193011641353549)
xi_corr(x, x^3): (0.7272727272727273, 0.00013825695781910508)
xi_corr(x, sin(x)): 0.696969696969697
xi_corr(x, sqrt(|x|)): (0.5454545454545455, 0.003193011641353549)
xi_corr2(x, y) (-0.1515151515151516, 0.7756475012431825)
xi_corr2(x, x): (0.7272727272727273, 0.00013825695781910508)
xi_corr2(x, x^2): (0.5454545454545454, 0.003193011641353549)
xi_corr2(x, x^3): (0.7272727272727273, 0.00013825695781910508)
xi_corr2(x, sin(x)): 0.696969696969697
xi_corr2(x, sqrt(|x|)): (0.5454545454545454, 0.003193011641353549)
--------------------------------------------------
                      N=100
--------------------------------------------------
xi_corr(x, y) (-0.16351635163516387, 0.9951369854408957)
xi_corr(x, x): (0.9702970297029703, 0.0)
xi_corr(x, x^2): (0.9411941194119412, 0.0)
xi_corr(x, x^3): (0.9702970297029703, 0.0)
xi_corr(x, sin(x)): 0.9567956795679569
xi_corr(x, sqrt(|x|)): (0.9411941194119412, 0.0)
xi_corr2(x, y) (-0.16351635163516343, 0.9951369854408956)
xi_corr2(x, x): (0.9702970297029703, 0.0)
xi_corr2(x, x^2): (0.9411941194119412, 0.0)
xi_corr2(x, x^3): (0.9702970297029703, 0.0)
xi_corr2(x, sin(x)): 0.9567956795679567
xi_corr2(x, sqrt(|x|)): (0.9411941194119412, 0.0)
--------------------------------------------------
                      N=1000
--------------------------------------------------
xi_corr(x, y) (-0.031245031245031196, 0.9408856310111086)
xi_corr(x, x): (0.997002997002997, 0.0)
xi_corr(x, x^2): (0.994011994011994, 0.0)
xi_corr(x, x^3): (0.997002997002997, 0.0)
xi_corr(x, sin(x)): 0.9941109941109941
xi_corr(x, sqrt(|x|)): (0.994011994011994, 0.0)
xi_corr2(x, y) (-0.031245031245031196, 0.9408856310111086)
xi_corr2(x, x): (0.997002997002997, 0.0)
xi_corr2(x, x^2): (0.994011994011994, 0.0)
xi_corr2(x, x^3): (0.997002997002997, 0.0)
xi_corr2(x, sin(x)): 0.9941109941109941
xi_corr2(x, sqrt(|x|)): (0.994011994011994, 0.0)
--------------------------------------------------
                     N=10000
--------------------------------------------------
xi_corr(x, y) (-0.002851350028513666, 0.6739468532599158)
xi_corr(x, x): (0.9997000299970003, 0.0)
xi_corr(x, x^2): (0.9994000899940009, 0.0)
xi_corr(x, x^3): (0.9997000299970003, 0.0)
xi_corr(x, sin(x)): 0.9991856799918568
xi_corr(x, sqrt(|x|)): (0.9994000899940009, 0.0)
xi_corr2(x, y) (-0.002851350028513444, 0.6739468532599031)
xi_corr2(x, x): (0.9997000299970003, 0.0)
xi_corr2(x, x^2): (0.9994000899940009, 0.0)
xi_corr2(x, x^3): (0.9997000299970003, 0.0)
xi_corr2(x, sin(x)): 0.9991856799918568
xi_corr2(x, sqrt(|x|)): (0.9994000899940009, 0.0)
"""

Results look the same for both implementations.

@dustinstansbury
Copy link
Author

dustinstansbury commented Jan 5, 2022

In [38]: %timeit xi_corr(x, y)
3.99 ms ± 22.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [39]: %timeit xi_corr2(x, y)
1.86 ms ± 5.81 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

Not 300x faster, but the second implementation does look a little over 2x faster.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment