{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is a re-implementation of the experiments from [Kernel Two-Sample Hypothesis Testing Using Kernel Set Classification](https://arxiv.org/abs/1706.05612)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from sklearn.metrics.pairwise import euclidean_distances\n",
"\n",
"def median_heuristic(X, Y, median_samples=1000):\n",
" sub = lambda feats, n: feats[np.random.choice(\n",
" feats.shape[0], min(feats.shape[0], n), replace=False)]\n",
" Z = np.r_[sub(X, median_samples // 2), sub(Y, median_samples // 2)]\n",
" D2 = euclidean_distances(Z, squared=True)\n",
" upper = D2[np.triu_indices_from(D2, k=1)]\n",
" kernel_width = np.median(upper, overwrite_input=True)\n",
" bandwidth = np.sqrt(kernel_width / 2)\n",
" return bandwidth"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Problem from the paper: $N(0, I)$ versus $N(0, v I)$."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def gen_data(d, n=20, m=100, var=1.01):\n",
" X = np.random.randn(n, d)\n",
" Y = np.random.randn(m, d) * var\n",
" return X, Y"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## MMD version"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, here's [some existing wrapper code](https://github.com/dougalsutherland/opt-mmd/blob/master/two_sample/mmd_test.py) to compute MMD tests with Shogun's fast permutations for getting thresholds."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import modshogun as sg\n",
"\n",
"import multiprocessing as mp\n",
"num_threads = mp.cpu_count()\n",
"sg.get_global_parallel().set_num_threads(num_threads)\n",
"\n",
"\n",
"def rbf_mmd_test(X, Y, bandwidth='median', null_samples=1001,\n",
" median_samples=1000, cache_size=32):\n",
" '''\n",
" Run an MMD test using a Gaussian kernel.\n",
" Parameters\n",
" ----------\n",
" X : row-instance feature array\n",
" Y : row-instance feature array\n",
" bandwidth : float or 'median'\n",
" The bandwidth of the RBF kernel (sigma).\n",
" If 'median', estimates the median pairwise distance in the\n",
" aggregate sample and uses that.\n",
" null_samples : int\n",
" How many times to sample from the null distribution.\n",
" median_samples : int\n",
" How many points to use for estimating the bandwidth.\n",
" Returns\n",
" -------\n",
" p_val : float\n",
" The obtained p value of the test.\n",
" stat : float\n",
" The test statistic.\n",
" null_samples : array of length null_samples\n",
" The samples from the null distribution.\n",
" bandwidth : float\n",
" The used kernel bandwidth\n",
" '''\n",
"\n",
" if bandwidth == 'median':\n",
" bandwidth = median_heuristic(X, Y, median_samples=median_samples)\n",
" kernel_width = 2 * bandwidth**2\n",
"\n",
" mmd = sg.QuadraticTimeMMD()\n",
" mmd.set_p(sg.RealFeatures(X.T.astype(np.float64)))\n",
" mmd.set_q(sg.RealFeatures(Y.T.astype(np.float64)))\n",
" mmd.set_kernel(sg.GaussianKernel(cache_size, kernel_width))\n",
"\n",
" mmd.set_num_null_samples(null_samples)\n",
" samps = mmd.sample_null()\n",
" stat = mmd.compute_statistic()\n",
"\n",
" p_val = np.mean(stat <= samps)\n",
" return p_val, stat, samps, bandwidth"
]
},
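{
"cell_type": "markdown",
"metadata": {},
"source": [
"In case Shogun isn't available, here's a rough NumPy/scikit-learn sketch of the same kind of test: the biased quadratic-time estimate of the squared MMD with a permutation null. It reuses `median_heuristic` from above; `rbf_mmd_test_numpy` is just an illustrative name, not the Shogun-based function the experiments below actually use (and it will be much slower for repeated tests)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# A dependency-light sketch (not used below): biased quadratic-time MMD^2\n",
"# with a permutation null, reusing median_heuristic from above.\n",
"from sklearn.metrics.pairwise import rbf_kernel\n",
"\n",
"def rbf_mmd_test_numpy(X, Y, bandwidth='median', null_samples=1000):\n",
"    if bandwidth == 'median':\n",
"        bandwidth = median_heuristic(X, Y)\n",
"    gamma = 1 / (2 * bandwidth**2)\n",
"    n = X.shape[0]\n",
"    Z = np.r_[X, Y]\n",
"    K = rbf_kernel(Z, gamma=gamma)  # kernel matrix on the pooled sample\n",
"\n",
"    def mmd2(idx):\n",
"        # Biased MMD^2 estimate for the X/Y split given by the index order.\n",
"        Kp = K[np.ix_(idx, idx)]\n",
"        return (Kp[:n, :n].mean() + Kp[n:, n:].mean()\n",
"                - 2 * Kp[:n, n:].mean())\n",
"\n",
"    stat = mmd2(np.arange(len(Z)))\n",
"    null = np.array([mmd2(np.random.permutation(len(Z)))\n",
"                     for _ in range(null_samples)])\n",
"    p_val = np.mean(stat <= null)\n",
"    return p_val, stat, null, bandwidth"
]
},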
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Type 2 error: don't reject when null is false."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 49.1 s, sys: 37.8 s, total: 1min 26s\n",
"Wall time: 3.87 s\n"
]
}
],
"source": [
"%%time\n",
"mmd_ps = np.array([\n",
" rbf_mmd_test(*gen_data(5))[0] for _ in range(100)])"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.92000000000000004"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(mmd_ps > .05).mean()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Type 1 error: reject when null is true."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 47.1 s, sys: 35.5 s, total: 1min 22s\n",
"Wall time: 3.65 s\n"
]
}
],
"source": [
"%%time\n",
"mmd_ps_null = np.array([\n",
" rbf_mmd_test(*gen_data(5, var=1))[0] for _ in range(100)])"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.070000000000000007"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(mmd_ps_null <= .05).mean()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This test is basically not doing anything, as claimed in the paper."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Kernel set classification method"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from sklearn.metrics.pairwise import rbf_kernel\n",
"from sklearn.svm import OneClassSVM"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def set_classification_test(X, Y, bw=1, nu=0.1, sub_size=7, n_subs=100):\n",
" gamma = 1 / (2 * bw**2)\n",
" KX = rbf_kernel(X, gamma=gamma)\n",
" KXY = rbf_kernel(X, Y, gamma=gamma)\n",
" \n",
" n = X.shape[0]\n",
" subs = np.vstack([\n",
" np.random.choice(n, sub_size, replace=False) for _ in range(n_subs)])\n",
" \n",
" K_subs = np.array([\n",
" [KX[np.ix_(si, sj)].mean() for sj in subs]\n",
" for si in subs])\n",
" K_test = np.array([\n",
" KXY[sub, :].mean() for sub in subs])[np.newaxis]\n",
" \n",
" svm = OneClassSVM(kernel='precomputed', nu=nu)\n",
" svm.fit(K_subs)\n",
" return svm.predict(K_test)[0] > 0"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Type 2 error: don't reject when null is false."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 44.8 s, sys: 414 ms, total: 45.2 s\n",
"Wall time: 44.7 s\n"
]
}
],
"source": [
"%%time\n",
"svm_rejs = np.array([\n",
" set_classification_test(*gen_data(5)) for _ in range(100)])"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1.0"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(~svm_rejs).mean()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Type 1 error: reject when null is true."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 44.5 s, sys: 539 ms, total: 45.1 s\n",
"Wall time: 44.5 s\n"
]
}
],
"source": [
"%%time\n",
"svm_rejs_null = np.array([\n",
" set_classification_test(*gen_data(5, var=1)) for _ in range(100)])"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.0"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"svm_rejs_null.mean()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"So this test _never_ rejected the null, unlike the results from the paper."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Theoretical best possible test"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If we know the exact generative model, the best test would compute the probability that $X$ is from either of the two distributions, $N(0, I)$ versus $N(0, v I)$; the same for $Y$, and then look at the probability that they match up (since they're indpendent)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The likelihood of $N(0, v I)$ is\n",
"$$\n",
"\\exp\\left( -\\frac{d}{2} \\log (2 \\pi v) - \\frac{1}{2 v} \\lVert x \\rVert^2 \\right)\n",
".$$"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"so, for a sample set $X$, the probability that it came from $N(0, I)$ as opposed to $N(0, v I)$ is\n",
"\\begin{align}\n",
"p_X\n",
"&=\n",
"\\frac{\n",
"\\exp\\left( -\\frac{n d}{2} \\log (2 \\pi) - \\sum_{i=1}^n \\frac{1}{2} \\lVert X_i \\rVert^2 \\right)\n",
"}{\n",
"\\exp\\left( -\\frac{n d}{2} \\log (2 \\pi) - \\sum_{i=1}^n \\frac{1}{2} \\lVert X_i \\rVert^2 \\right)\n",
"+ \\exp\\left( -\\frac{n d}{2} \\log (2 \\pi) - \\frac{n d}{2} \\log v - \\sum_{i=1}^n \\frac{1}{2 v} \\lVert X_i \\rVert^2 \\right)\n",
"}\n",
"\\\\ &=\n",
"\\frac{\n",
"\\exp\\left( - \\frac12 \\sum_{i=1}^n \\lVert X_i \\rVert^2 \\right)\n",
"}{\n",
"\\exp\\left( - \\frac12 \\sum_{i=1}^n \\lVert X_i \\rVert^2 \\right)\n",
"+ \\exp\\left( - \\frac{n d}{2} \\log v - \\frac1{2 v} \\sum_{i=1}^n \\lVert X_i \\rVert^2 \\right)\n",
"}\n",
"\\end{align}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"So $p_X$ is the probability that sample $X$ came from $N(0, 1)$, $p_Y$ the probability that sample $Y$ did. Thus the probability that the two came from the _same_ distribution (either of the two) is $p_X p_Y + (1 - p_X) (1 - p_Y)$."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's try that in the same way:"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def prob_1(X, v=1.01):\n",
" n, d = X.shape\n",
" s = .5 * (X ** 2).sum()\n",
" l1 = np.exp(-s)\n",
" l2 = np.exp(-n * d / 2 * np.log(v) - s / v)\n",
" return l1 / (l1 + l2)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def prob_same(X, Y, v=1.01):\n",
" pX = prob_1(X, v)\n",
" pY = prob_1(Y, v)\n",
" return pX * pY + (1 - pX) * (1 - pY)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"probs = np.array([\n",
" prob_same(*gen_data(5)) for _ in range(100)])"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(0.50022106675412803, 0.49573053015762769, 0.5102717879843921)"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.mean(probs), np.min(probs), np.max(probs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"So the distributions are basically theoretically indistinguishable at this sample size, even with $v = 1.01$ instead of $1 + 10^{-21}$."
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"probs_null = np.array([\n",
" prob_same(*gen_data(5, 1)) for _ in range(100)])"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(0.50002290599779586, 0.49912412130653572, 0.50083149225372237)"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.mean(probs_null), np.min(probs_null), np.max(probs_null)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"They're also indistinguishable when they're the same distributions, unsurprisingly."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:shogun]",
"language": "python",
"name": "conda-env-shogun-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}