Skip to content

Instantly share code, notes, and snippets.

@jamestwebber
Last active August 21, 2023 16:39
Show Gist options
  • Star 4 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save jamestwebber/38ab26d281f97feb8196b3d93edeeb7b to your computer and use it in GitHub Desktop.
Save jamestwebber/38ab26d281f97feb8196b3d93edeeb7b to your computer and use it in GitHub Desktop.
numba-fied, parallelized Mann-Whitney U-test for 2d arrays
import numpy as np
import numba as nb
import scipy.stats
@nb.njit(parallel=True)
def tiecorrect(rankvals):
    """Column-wise, parallel port of ``scipy.stats.tiecorrect``.

    For each column the factor is ``1 - sum(t**3 - t) / (n**3 - n)``, where
    ``n`` is the number of values in the column and ``t`` runs over the sizes
    of the groups of equal values. Columns holding fewer than two values keep
    a factor of 1.0.

    Parameters
    ----------
    rankvals : ndarray
        Array whose columns (second axis) are processed independently;
        normally the output of ``rankdata``.

    Returns
    -------
    ndarray of float64, shape (rankvals.shape[1],)
        One tie-correction factor per column. The factor is 0.0 when every
        entry of the column is identical.
    """
    n_cols = rankvals.shape[1]
    factors = np.ones(n_cols, dtype=np.float64)
    for col in nb.prange(n_cols):
        ordered = np.sort(np.ravel(rankvals[:, col]))
        n = np.float64(ordered.size)
        if n >= 2:
            # Value changes in the sorted column, bracketed by sentinels at
            # both ends, mark the boundaries of each run of tied values.
            boundary = np.concatenate(
                (np.array([True]), ordered[1:] != ordered[:-1], np.array([True]))
            )
            edges = np.nonzero(boundary)[0]
            group_sizes = np.diff(edges).astype(np.float64)
            factors[col] = 1.0 - np.sum(group_sizes**3 - group_sizes) / (n**3 - n)
    return factors
@nb.njit(parallel=True)
def rankdata(data):
    """Column-wise, parallel port of ``scipy.stats.rankdata`` ('average' method).

    Each column of ``data`` is ranked independently. Values receive ranks
    1..n in sorted order, and tied values share the average of the ranks
    they span.

    Parameters
    ----------
    data : ndarray, 2-D
        Values to rank; ranking runs down each column.

    Returns
    -------
    ndarray of float64, same shape as ``data``
        The average ranks.
    """
    ranked = np.empty(data.shape, dtype=np.float64)
    n_rows = data.shape[0]
    if n_rows == 0:
        # Nothing to rank. Without this guard the loop below builds a
        # length-1 "new group" mask for an empty column. ``dense`` would then
        # never be written, and ``count`` would be indexed with an
        # uninitialised value.
        return ranked
    for j in nb.prange(data.shape[1]):
        sorter = np.argsort(data[:, j])
        arr = data[:, j][sorter]
        # obs[k] is True where the k-th smallest value starts a new tie group.
        obs = np.concatenate((np.array([True]), arr[1:] != arr[:-1]))
        # 1-based tie-group index of every value, scattered back to the
        # original row order.
        dense = np.empty(n_rows, dtype=np.int64)
        dense[sorter] = obs.cumsum()
        # count[g - 1] is the 0-based start of tie group g in sorted order and
        # count[g] is its exclusive end, so group g spans ranks
        # count[g - 1] + 1 .. count[g]; their mean is assigned to every member.
        count = np.concatenate((np.nonzero(obs)[0], np.array([n_rows])))
        ranked[:, j] = 0.5 * (count[dense] + count[dense - 1] + 1)
    return ranked
def mannwhitneyu(x, y, use_continuity=True):
    """Two-sided Mann-Whitney U-test run independently on every column.

    Column-wise version of the asymptotic (normal-approximation) algorithm of
    ``scipy.stats.mannwhitneyu``. Ranking and tie correction are done by the
    parallel ``rankdata`` and ``tiecorrect`` helpers. Only the two-sided
    alternative is implemented.

    Parameters
    ----------
    x : array_like, shape (n1, k)
        First sample; column ``j`` is tested against column ``j`` of ``y``.
    y : array_like, shape (n2, k)
        Second sample; must have the same number of columns as ``x``.
    use_continuity : bool, optional
        Apply the 0.5 continuity correction (default True).

    Returns
    -------
    u : ndarray of float64, shape (k,)
        U statistic of ``x``: ``R_x - n1 * (n1 + 1) / 2``, where ``R_x`` is
        the rank sum of ``x`` within the pooled column.
    p : ndarray of float64, shape (k,)
        Two-sided p-values, clipped to [0, 1]. A column whose pooled values
        are all identical has zero variance. Its p-value is 1.0 with the
        continuity correction and NaN without it.

    Raises
    ------
    ValueError
        If ``x`` or ``y`` is not 2-D, if their column counts differ, or if
        every column's pooled values are identical.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    # Explicit checks rather than ``assert``, so validation survives
    # ``python -O``. This also matches the ValueError np.concatenate would
    # raise on a column mismatch.
    if x.ndim != 2 or y.ndim != 2:
        raise ValueError(
            f"x and y must be 2-D arrays, got {x.ndim}-D and {y.ndim}-D"
        )
    if x.shape[1] != y.shape[1]:
        raise ValueError(
            "x and y must have the same number of columns, "
            f"got {x.shape[1]} and {y.shape[1]}"
        )
    n1 = x.shape[0]
    n2 = y.shape[0]

    ranked = rankdata(np.concatenate((x, y)))
    rank_sum_x = np.sum(ranked[:n1, :], axis=0)
    # n1*n2 + n1*(n1+1)/2 - R_x is the U statistic of y; its complement
    # with respect to n1*n2 is the U statistic of x.
    u_y = n1 * n2 + (n1 * (n1 + 1)) / 2.0 - rank_sum_x
    u_x = n1 * n2 - u_y

    T = tiecorrect(ranked)
    # Only fail when *every* column is constant; individual constant columns
    # are allowed. An input with zero columns yields empty results instead
    # of this (misleading) error.
    if T.size > 0 and np.all(T == 0):
        raise ValueError('All numbers are identical in mannwhitneyu')

    sd = np.sqrt(T * n1 * n2 * (n1 + n2 + 1) / 12.0)
    meanrank = n1 * n2 / 2.0 + 0.5 * use_continuity
    # u_x + u_y == n1*n2, so bigu >= n1*n2/2. z can only go negative through
    # the continuity correction; the clip keeps 2 * sf(z) from exceeding 1.
    bigu = np.maximum(u_x, u_y)
    with np.errstate(divide='ignore', invalid='ignore'):
        z = (bigu - meanrank) / sd
    p = np.clip(2 * scipy.stats.norm.sf(z), 0, 1)
    return u_x, p
@jamestwebber
Copy link
Author

Note: this code was adapted from the SciPy codebase (the `scipy.stats.rankdata`, `scipy.stats.tiecorrect`, and `scipy.stats.mannwhitneyu` implementations). I found that the specialized tie-correction routine in SciPy's Mann-Whitney module was actually much slower than this version.

The license is the same as SciPy's: BSD 3-Clause.

@jamestwebber
Copy link
Author

Small tweak: because u1 and u2 always have the same shape, it's better to use `np.maximum(u1, u2)` rather than `np.max(np.vstack((u1, u2)), 0)`.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment