Skip to content

Instantly share code, notes, and snippets.

@mtreviso
Last active May 31, 2023 10:19
Show Gist options
  • Save mtreviso/9e845d398ecb67de1c6e9a057501231b to your computer and use it in GitHub Desktop.
Save mtreviso/9e845d398ecb67de1c6e9a057501231b to your computer and use it in GitHub Desktop.
Kendall's tau implementations as proposed in the paper "Ties Matter: Modifying Kendall's Tau for Modern Metric Meta-Evaluation": https://arxiv.org/abs/2305.14324
import numpy as np
def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator`` as a float, or NaN if the denominator is zero."""
    if denominator == 0:
        return float("nan")
    return float(numerator / denominator)


def _pair_signs(values, first, second):
    """Return the sign (-1, 0, +1) of ``values[first] - values[second]`` per pair.

    The sign is obtained from ordering comparisons rather than from an actual
    subtraction, so unsigned integers cannot wrap around, boolean arrays are
    accepted, and ``inf`` compared with ``inf`` is a tie instead of NaN.
    """
    left = values[first]
    right = values[second]
    return (left > right).astype(np.int8) - (left < right).astype(np.int8)


def compute_kendall_taus(h, m):
    """
    Compute several variants of Kendall's tau between two score/rank arrays.

    Every unordered pair of positions ``(i, j)`` with ``i > j`` is classified as
    concordant, discordant, tied only in ``h``, tied only in ``m`` or tied in
    both. The variants tau_a, tau_b, tau_c, tau_10, tau_13, tau_14, tau_23 and
    the pairwise accuracy acc_23 are then derived from those five counts,
    following "Ties Matter: Modifying Kendall's Tau for Modern Metric
    Meta-Evaluation" (https://arxiv.org/abs/2305.14324).

    Parameters
    ----------
    h, m : array-like
        The two arrays to compare. They must contain the same number of
        elements; arrays that are not 1-D are flattened first.

    Returns
    -------
    dict
        Keys ``'C'``, ``'D'``, ``'Th'``, ``'Tm'``, ``'Thm'`` hold the pair counts
        (concordant, discordant, tied in ``h`` only, tied in ``m`` only, tied in
        both). Keys ``'tau_a'``, ``'tau_b'``, ``'tau_c'``, ``'tau_10'``,
        ``'tau_13'``, ``'tau_14'``, ``'tau_23'`` and ``'acc_23'`` hold the
        statistics as floats. A statistic whose denominator is zero (for example
        with fewer than two elements, or a constant input) is NaN.

    Raises
    ------
    AssertionError
        If ``h`` and ``m`` do not contain the same number of elements.

    Notes
    -----
    All n * (n - 1) / 2 pairs are compared explicitly, so time and memory are
    O(n^2). tau_c is Stuart's tau-c, ``2 * (C - D) / (n^2 * (k - 1) / k)`` with
    ``k`` the smaller number of distinct values in ``h`` and ``m``; the factor 2
    is what lets a perfect, tie-free agreement reach 1. NaN entries compare as
    neither greater nor smaller, so pairs involving a NaN count as ties.
    """
    # Flatten to 1-D views; the inputs are never modified, so no copy is needed.
    h = np.asarray(h).ravel()
    m = np.asarray(m).ravel()

    # Raised explicitly (instead of via `assert`) so that the check still runs
    # under `python -O`; the exception type and message are kept for callers.
    if h.size != m.size:
        raise AssertionError("Input arrays must be of the same size")

    n = h.size
    # Smaller number of distinct values among the two arrays (used by tau_c).
    num_levels = min(np.unique(h).size, np.unique(m).size)

    # Enumerate each unordered pair exactly once (strict lower triangle).
    first, second = np.tril_indices(n, k=-1)
    h_sign = _pair_signs(h, first, second)
    m_sign = _pair_signs(m, first, second)

    h_tied = h_sign == 0
    m_tied = m_sign == 0
    untied = ~h_tied & ~m_tied

    C = int(np.count_nonzero(untied & (h_sign == m_sign)))
    D = int(np.count_nonzero(untied & (h_sign != m_sign)))
    Th = int(np.count_nonzero(h_tied & ~m_tied))
    Tm = int(np.count_nonzero(~h_tied & m_tied))
    Thm = int(np.count_nonzero(h_tied & m_tied))

    all_pairs = C + D + Th + Tm + Thm
    # With no elements there are no levels either; fall back to a zero
    # denominator so tau_c becomes NaN like every other undefined statistic.
    tau_c_denominator = n * n * (num_levels - 1) / num_levels if num_levels else 0

    tau_a = _safe_ratio(C - D, all_pairs)
    tau_b = _safe_ratio(C - D, ((C + D + Th) * (C + D + Tm)) ** 0.5)
    tau_c = _safe_ratio(2 * (C - D), tau_c_denominator)
    tau_10 = _safe_ratio(C - D - Tm, C + D + Tm)
    tau_13 = _safe_ratio(C - D, C + D)
    tau_14 = _safe_ratio(C - D, C + D + Tm)
    tau_23 = _safe_ratio(C + Thm - D - Th - Tm, all_pairs)

    # Pairwise accuracy: a pair is "correct" when both arrays order it the same
    # way or both tie it.
    acc_23 = _safe_ratio(C + Thm, all_pairs)

    return {
        'C': C,
        'D': D,
        'Th': Th,
        'Tm': Tm,
        'Thm': Thm,
        'tau_a': tau_a,
        'tau_b': tau_b,
        'tau_c': tau_c,
        'tau_10': tau_10,
        'tau_13': tau_13,
        'tau_14': tau_14,
        'tau_23': tau_23,
        'acc_23': acc_23,
    }
@mtreviso
Copy link
Author

Sample inputs:

h = np.array([0, 0, 0, 0, 1, 2])
m1 = np.array([0, 0, 0, 0, 2, 1])
m2 = np.array([0, 1, 2, 3, 4, 5])

Sample outputs:

>>> compute_kendall_taus(h, m1)
{'C': 8,
 'D': 1,
 'Th': 0,
 'Tm': 0,
 'Thm': 6,
 'tau_a': 0.4666666666666667,
 'tau_b': 0.7777777777777778,
 'tau_c': 0.2916666666666667,
 'tau_10': 0.7777777777777778,
 'tau_13': 0.7777777777777778,
 'tau_14': 0.7777777777777778,
 'tau_23': 0.8666666666666667,
 'acc_23': 0.9333333333333333}

>>> compute_kendall_taus(h, m2)
{'C': 9,
 'D': 0,
 'Th': 6,
 'Tm': 0,
 'Thm': 0,
 'tau_a': 0.6,
 'tau_b': 0.7745966692414834,
 'tau_c': 0.375,
 'tau_10': 1.0,
 'tau_13': 1.0,
 'tau_14': 1.0,
 'tau_23': 0.2,
 'acc_23': 0.6}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment