Skip to content

Instantly share code, notes, and snippets.

@ili3p
Created December 28, 2021 09:13
Show Gist options
  • Save ili3p/f2b38b898f6eab0d87ec248ea39fde94 to your computer and use it in GitHub Desktop.
Save ili3p/f2b38b898f6eab0d87ec248ea39fde94 to your computer and use it in GitHub Desktop.
Fast Kendall tau calculation with PyTorch.
import torch
import time
from scipy.stats import kendalltau
def kendall(x, y):
    """Return the Kendall rank correlation (tau-a) between two tensors.

    Every ordered pair (i, j) contributes
    ``sign(x[i] - x[j]) * sign(y[i] - y[j])``; the contributions are summed
    and divided by ``n * (n - 1)``, the number of ordered pairs with i != j.
    No tie correction is applied, so this matches SciPy's default (tau-b)
    only when neither input contains ties.

    Args:
        x: Tensor of observations, expected to be 1-D with ``n`` elements.
        y: Tensor of observations with the same shape as ``x``. Its dtype
            may differ from ``x`` (e.g. integer ranks vs. float scores).

    Returns:
        A 0-dim tensor holding tau. With fewer than two elements the
        denominator is zero, so the result is not a meaningful correlation.

    Raises:
        ValueError: If ``x`` and ``y`` do not have the same shape.

    Note:
        Two ``n x n`` intermediates are materialised, so memory grows
        quadratically with ``n``.
    """
    # Without this check a length-1 ``y`` would be silently broadcast by
    # ``expand`` and produce a bogus correlation of 0.
    if x.shape != y.shape:
        raise ValueError(
            f"x and y must have the same shape, got "
            f"{tuple(x.shape)} and {tuple(y.shape)}"
        )

    n = x.shape[0]

    def sub_pairs(v):
        # (n, n) matrix whose [i, j] entry is sign(v[i] - v[j]).
        return v.expand(n, n).T.sub(v).sign_()

    sign_x, sign_y = sub_pairs(x), sub_pairs(y)

    # An in-place multiply cannot write a promoted (e.g. float) result into
    # an integer tensor, so bring both sign matrices to a common dtype first.
    # ``.to`` is a no-op for a tensor that already has that dtype.
    if sign_x.dtype != sign_y.dtype:
        common = torch.promote_types(sign_x.dtype, sign_y.dtype)
        sign_x, sign_y = sign_x.to(common), sign_y.to(common)

    return sign_x.mul_(sign_y).sum().div(n * (n - 1))
# Benchmark: time kendall() on random permutations and, for each run, print
# the elapsed time and the absolute difference from SciPy's kendalltau.
n_runs = 10
n_items = 4000

d = torch.empty(n_runs)  # per-run elapsed time in milliseconds
for i in range(n_runs):
    # Permutations contain no ties, so tau-a and SciPy's tau-b coincide.
    x, y = torch.randperm(n_items), torch.randperm(n_items)

    # perf_counter_ns is monotonic, unlike the wall clock, which can jump
    # while an interval is being measured.
    t = time.perf_counter_ns()
    m = kendall(x, y)
    d[i] = (time.perf_counter_ns() - t) * 1e-6  # ns -> ms

    print(f'{d[i]:.2f}ms')
    print(f'{abs(kendalltau(x,y).correlation - m):.9f}')

print(f'AVG {d.mean():.2f}ms')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment