Skip to content

Instantly share code, notes, and snippets.

@ahwillia
Last active February 7, 2021 18:31
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ahwillia/661adb1703ba409a5630577155117946 to your computer and use it in GitHub Desktop.
Save ahwillia/661adb1703ba409a5630577155117946 to your computer and use it in GitHub Desktop.
Two-sample permutation test in Python
"""
A simple implementation of a permutation test among two
independent samples.
"""
import numpy as np
from sklearn.utils.validation import check_random_state
from more_itertools import distinct_permutations
from scipy.stats import percentileofscore
from math import factorial
def permtest(
        x, y, statistic="mean", max_samples=100000,
        random_state=None, verbose=True):
    """
    Conducts a two-sided permutation test between two independent
    samples.

    The pooled datapoints are repeatedly re-split into groups with the
    original sample sizes, and ``statistic`` is recomputed on each split.
    If ``max_samples`` is large enough to cover every distinct split, all
    splits are enumerated (exact test). Otherwise ``max_samples`` random
    splits are drawn (Monte Carlo test).

    Parameters
    ----------
    x : array_like
        First set of datapoints (expected to be 1-D).
    y : array_like
        Second set of datapoints (expected to be 1-D).
    statistic : str or callable, optional
        Function that takes in samples x and y and reports the
        statistic of interest. By default, "mean" reports the
        difference of sample means. Built-in options are
        ("mean", "median").
    max_samples : int, optional
        Maximum number of label permutations to try.
    random_state : np.random.RandomState, int, or None
        If specified, used to seed the random number generator that
        draws random label permutations.
    verbose : bool, optional
        If True (default), print coverage and progress messages.

    Returns
    -------
    pval : float
        Two-sided p-value ``min(1, 2 * min(P(T <= t), P(T >= t)))``,
        where ``t = statistic(x, y)`` and ``T`` is the statistic under
        random relabeling. Monte Carlo estimates count the observed
        labeling as one extra draw, so they are never exactly zero.

    Raises
    ------
    ValueError
        If either sample is empty, ``max_samples`` is below 1, or
        ``statistic`` is not recognized.
    """
    rs = check_random_state(random_state)
    stat_func = _get_stat_func(statistic)

    # Accept lists/tuples as well as arrays.
    x = np.asarray(x)
    y = np.asarray(y)
    if x.size == 0 or y.size == 0:
        raise ValueError("Both samples must contain at least one datapoint.")
    max_samples = int(max_samples)
    if max_samples < 1:
        raise ValueError("`max_samples` must be a positive integer.")

    # Pooled data; labels are True for datapoints assigned to "x".
    xy = np.concatenate((x, y))
    labels = np.zeros(xy.size, dtype=bool)
    labels[:x.size] = True

    # Number of distinct labelings (binomial coefficient, exact integer).
    n_perms = factorial(xy.size) // factorial(x.size) // factorial(y.size)

    # Enumerate only when every distinct labeling fits in the budget. A
    # truncated enumeration would be a deterministic, ordered -- and hence
    # biased -- subset of the null distribution, so sample randomly instead.
    exact = n_perms <= max_samples
    n_samples = n_perms if exact else max_samples

    if verbose:
        # Decimal formats arbitrarily large ints in scientific notation;
        # float conversion of n_perms overflows for large samples.
        from decimal import Decimal
        print("Computing {0:} / {1:.2e} ({2:2.2f}%) of label permutations: ".format(
            n_samples, Decimal(n_perms), 100 * n_samples / n_perms
        ))
        if exact:
            print("Iterating over distinct permutations...")
        else:
            print("Sampling random permutations...")

    perms = distinct_permutations(labels) if exact else _randperms(rs, labels)

    # Statistic recomputed under each relabeling (the null distribution).
    from itertools import islice
    null = np.empty(n_samples)
    for i, perm in enumerate(islice(perms, n_samples)):
        mask = np.asarray(perm, dtype=bool)
        null[i] = stat_func(xy[mask], xy[~mask])

    observed = stat_func(x, y)
    # Relative slack so that ties with the observed statistic survive
    # floating-point reordering of the summed datapoints.
    gamma = abs(1e-14 * observed)
    # An exact enumeration already contains the observed labeling; a Monte
    # Carlo sample adds it as one extra draw so the p-value cannot be 0.
    adjustment = 0 if exact else 1
    n_less = np.count_nonzero(null <= observed + gamma)
    n_greater = np.count_nonzero(null >= observed - gamma)
    p_less = (n_less + adjustment) / (n_samples + adjustment)
    p_greater = (n_greater + adjustment) / (n_samples + adjustment)

    # Two-sided: double the smaller tail and cap at 1 (ties can push it over).
    return float(min(1.0, 2 * min(p_less, p_greater)))
def _get_stat_func(name_or_func):
    """
    Resolve ``name_or_func`` into a two-sample statistic function.

    A callable is passed through unchanged. The strings "mean" and
    "median" map to functions returning that summary of the first sample
    minus the same summary of the second sample.

    Raises
    ------
    ValueError
        If given an unrecognized name, or an object that is neither a
        string nor callable.
    """
    if isinstance(name_or_func, str):
        builtin_stats = {
            "mean": lambda x, y: np.mean(x) - np.mean(y),
            "median": lambda x, y: np.median(x) - np.median(y),
        }
        stat = builtin_stats.get(name_or_func)
        if stat is None:
            raise ValueError(
                "Did not recognize statistic."
            )
        return stat

    # Anything that is not a string must be usable as f(x, y) directly.
    if not callable(name_or_func):
        raise ValueError(
            "`statistic` should be a string like ('mean', 'median')"
            " or a function that takes in samples x, y and returns"
            " the statistic of interest."
        )
    return name_or_func
def _randperms(rs, labels):
    """
    Yield an endless stream of random shuffles of ``labels``.

    A private working copy is shuffled in place on every step. Each
    successive draw continues from the previous ordering. A fresh copy is
    yielded each time, so callers may keep earlier draws without later
    shuffles overwriting them. ``labels`` itself is never modified.

    Parameters
    ----------
    rs : np.random.RandomState
        Source of randomness; only its ``shuffle`` method is used.
    labels : ndarray
        Labels to permute.

    Yields
    ------
    ndarray
        A new array holding a random permutation of ``labels``.
    """
    perm = labels.copy()
    while True:
        rs.shuffle(perm)
        # Copy so that consumers which store draws (e.g. list(islice(...)))
        # don't end up holding many references to one mutating array.
        yield perm.copy()
if __name__ == "__main__":
    # Demo: two samples of 100 standard-normal draws, the second shifted
    # by +0.5. scikit-learn's check_random_state(None) falls back to NumPy's
    # global RandomState, so seeding it here also fixes the permutations.
    np.random.seed(123)
    x = np.random.randn(100)
    y = np.random.randn(100) + .5
    p_value = permtest(x, y, random_state=None)
    print(f"p = {p_value}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment