Skip to content

Instantly share code, notes, and snippets.

@lucastheis
Created August 29, 2015 08:44
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lucastheis/9e5a85fb08597ebbbe99 to your computer and use it in GitHub Desktop.
Save lucastheis/9e5a85fb08597ebbbe99 to your computer and use it in GitHub Desktop.
Bias of standard error in leave-one-out cross-validation
Display the source blob
Display the rendered blob
Raw
{
"metadata": {
"name": "",
"signature": "sha256:0d47d4faa15ccfec3fccf4b8e613f403bcd30f2ce7ddee88b7ea4620dbf609c2"
},
"nbformat": 3,
"nbformat_minor": 0,
"worksheets": [
{
"cells": [
{
"cell_type": "code",
"collapsed": false,
"input": [
"%pylab inline"
],
"language": "python",
"metadata": {},
"outputs": [
{
"output_type": "stream",
"stream": "stdout",
"text": [
"Populating the interactive namespace from numpy and matplotlib\n"
]
}
],
"prompt_number": 1
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"from numpy.random import multinomial, dirichlet"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 2
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"def sample(p, N):\n",
" \"\"\"\n",
" Draw N one-hot samples from the histogram p.\n",
" \"\"\"\n",
" return multinomial(1, p, size=N)\n",
"\n",
"def estimate(X, r=1.):\n",
" \"\"\"\n",
" Estimate histogram from one-hot samples X with additive smoothing (pseudo-count r).\n",
" \"\"\"\n",
" return (sum(X, 0) + r) / (X.shape[0] + r * X.shape[1])\n",
"\n",
"def evaluate(X, q):\n",
" \"\"\"\n",
" Evaluate the negative log-likelihood of one-hot samples X under histogram q.\n",
" \"\"\"\n",
" return -sum(X * log(q))"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 73
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"# dataset size\n",
"N = 100\n",
"\n",
"# dimensionality\n",
"D = 100\n",
"\n",
"# true distribution\n",
"a = 10.\n",
"p = dirichlet([a / D] * D)"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 153
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"def cross_validate(X):\n",
" # leave-one-out cross-validation (NOTE: loops over the global N, so assumes X.shape[0] == N)\n",
" loss = []\n",
" for n in range(N):\n",
" q = estimate(delete(X, n, 0))\n",
" loss.append(evaluate(X[n], q))\n",
" \n",
" return mean(loss), std(loss, ddof=1)"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 150
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"# number of repetitions\n",
"R = 10000\n",
"\n",
"loss_avg = []\n",
"loss_std = []\n",
"\n",
"for r in range(R):\n",
" L, s = cross_validate(sample(p, N))\n",
" loss_avg.append(L)\n",
" loss_std.append(s)\n",
"\n",
"# average cross-validation estimates\n",
"print mean(loss_avg)\n",
"print mean(loss_std)"
],
"language": "python",
"metadata": {},
"outputs": [
{
"output_type": "stream",
"stream": "stdout",
"text": [
"3.56419933227\n",
"0.713336314565\n"
]
}
],
"prompt_number": 154
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"# number of independent training sets of size N - 1\n",
"R = 10000\n",
"\n",
"# generate R independent samples of the loss\n",
"logq = []\n",
"for r in range(R):\n",
" logq.append(log(estimate(sample(p, N - 1))))\n",
" \n",
"loss_avg = -sum(p * mean(logq, 0))\n",
"loss_std = sqrt(sum(p * mean(square(logq), 0)) - loss_avg**2)\n",
"\n",
"# Monte Carlo ground truth: expected test loss of a model trained on N - 1 points, and its standard deviation\n",
"print loss_avg\n",
"print loss_std"
],
"language": "python",
"metadata": {},
"outputs": [
{
"output_type": "stream",
"stream": "stdout",
"text": [
"3.5636444913\n",
"0.718542014501\n"
]
}
],
"prompt_number": 155
}
],
"metadata": {}
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment