A kernel two-sample test for equality of distributions (Gretton et al. 2012)
import numpy as np
from scipy.spatial.distance import cdist, pdist
def mmd_two_sample_test(X, Y):
"""
Implements Gretton's test for equality of
distributions in high-dimensional settings
using concentration bounds on the maximum
mean discrepancy (MMD). This function uses
the unbiased estimator of the MMD (see
Lemma 6, Gretton et al., 2012) and upper
bounds the p-value using a Hoeffding
large-deviation bound (see Theorem 10,
Gretton et al., 2012).
The test considers two sets of observed
datapoints, X and Y, which are assumed to
be drawn i.i.d. from underlying probability
distributions P and Q. The null hypothesis
is that P = Q.
Note that this function assumes that the number
of samples from each distribution are equal.
Reference
---------
Gretton et al. (2012). A Kernel Two-Sample Test.
Journal of Machine Learning Research 13: 723-773.
Parameters
----------
X : ndarray (num_samples x num_features)
First set of observed samples, assumed to be
drawn from some unknown distribution P.
Y : ndarray (num_samples x num_features)
Second set of observed samples, assumed to be
drawn from some unknown distribution Q.
Returns
-------
pvalue : float
An upper bound on the probability of observing
an MMD distance greater than or equal to the
observed value, assuming that the null hypothesis
(i.e. that P = Q) is true.
"""
    assert X.shape == Y.shape
    m = X.shape[0]

    # Compute pairwise distances.
    xd = pdist(X, metric="euclidean")
    yd = pdist(Y, metric="euclidean")
    xyd = cdist(X, Y, metric="euclidean").ravel()
    # Set kernel bandwidth (Gretton et al. suggest using the
    # median pairwise distance).
    sigma_sq = np.median(
        np.concatenate((xd, yd, xyd))
    ) ** 2
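
    # The kernel used throughout is the Gaussian RBF kernel,
    #
    #   k(x, y) = exp(-||x - y||^2 / (2 * sigma^2)),
    #
    # with sigma set to the median pairwise distance (the
    # "median heuristic"), so 0 <= k <= 1.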
    # Compute the unbiased estimate of the squared MMD.
    kxx = np.mean(np.exp(-(xd**2) / (2 * sigma_sq)))
    kyy = np.mean(np.exp(-(yd**2) / (2 * sigma_sq)))
    kxy = np.mean(np.exp(-(xyd**2) / (2 * sigma_sq)))
    mmd_obs = kxx + kyy - 2 * kxy
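
    # Theorem 10 gives a Hoeffding-style bound on the tail of
    # MMD_u^2 under the null hypothesis. For a kernel bounded
    # above by K, as I read the theorem,
    #
    #   Pr(MMD_u^2 >= t | P = Q) <= exp(-t^2 * floor(m/2) / (8 * K^2)),
    #
    # and K = 1 for the Gaussian kernel used here.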
    # Apply Theorem 10 to upper bound the p-value. The unbiased
    # estimator can be negative, in which case the bound is
    # vacuous and we return 1.
    if mmd_obs < 0:
        return 1.0
    else:
        return np.exp(
            -((mmd_obs ** 2) * (m // 2)) / 8
        )
if __name__ == "__main__":

    # TEST THAT WE FAIL TO REJECT THE NULL
    d = 10
    num_samples = 1000
    pvals = np.empty(100)

    for seed in range(pvals.size):
        # Draw random samples from equal distributions.
        rs = np.random.RandomState(seed)
        X = rs.randn(num_samples, d)
        Y = rs.randn(num_samples, d)
        pvals[seed] = mmd_two_sample_test(X, Y)

    print("FIRST TEST -- NULL HYPOTHESIS TRUE")
    print(f"{np.sum(pvals < 0.05)} / {pvals.size} tests reject the null.")

    # TEST THAT WE REJECT THE NULL
    for seed in range(pvals.size):
        # Draw random samples from unequal distributions
        # (the second sample is mean-shifted by 1).
        rs = np.random.RandomState(seed)
        X = rs.randn(num_samples, d)
        Y = rs.randn(num_samples, d) + 1
        pvals[seed] = mmd_two_sample_test(X, Y)

    print("SECOND TEST -- NULL HYPOTHESIS FALSE")
    print(f"{np.sum(pvals < 0.05)} / {pvals.size} tests reject the null.")
ahwillia commented Nov 13, 2022
Link to paper: https://www.jmlr.org/papers/v13/gretton12a.html

Output should be:

FIRST TEST -- NULL HYPOTHESIS TRUE
0 / 100 tests reject the null.
SECOND TEST -- NULL HYPOTHESIS FALSE
100 / 100 tests reject the null.
