Created
August 29, 2015 08:44
-
-
Save lucastheis/9e5a85fb08597ebbbe99 to your computer and use it in GitHub Desktop.
Bias of standard error in leave-one-out cross-validation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"metadata": { | |
"name": "", | |
"signature": "sha256:0d47d4faa15ccfec3fccf4b8e613f403bcd30f2ce7ddee88b7ea4620dbf609c2" | |
}, | |
"nbformat": 3, | |
"nbformat_minor": 0, | |
"worksheets": [ | |
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"%pylab inline" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"Populating the interactive namespace from numpy and matplotlib\n" | |
] | |
} | |
], | |
"prompt_number": 1 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"from numpy.random import multinomial, dirichlet" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [], | |
"prompt_number": 2 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"def sample(p, N):\n", | |
"    \"\"\"\n", | |
"    Draw N one-hot samples from the histogram with bin probabilities p.\n", | |
"\n", | |
"    Returns an (N, len(p)) array; each row contains a single 1 marking the\n", | |
"    sampled bin.\n", | |
"    \"\"\"\n", | |
"    return multinomial(1, p, size=N)\n", | |
"\n", | |
"def estimate(X, r=1.):\n", | |
"    \"\"\"\n", | |
"    Estimate histogram bin probabilities from one-hot rows X (as produced\n", | |
"    by `sample`).\n", | |
"\n", | |
"    Adds r pseudo-counts to every bin before normalizing, so no bin gets\n", | |
"    zero probability (r=1 is Laplace smoothing).\n", | |
"    \"\"\"\n", | |
"    # float() guards against integer division under Python 2 when r is an int\n", | |
"    return (sum(X, 0) + r) / float(X.shape[0] + r * X.shape[1])\n", | |
"\n", | |
"def evaluate(X, q):\n", | |
"    \"\"\"\n", | |
"    Evaluate the negative log-likelihood of one-hot data X under histogram q.\n", | |
"\n", | |
"    This is a loss: lower values mean q fits X better.\n", | |
"    \"\"\"\n", | |
"    return -sum(X * log(q))" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [], | |
"prompt_number": 73 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"# number of data points in each simulated dataset\n", | |
"N = 100\n", | |
"# number of histogram bins\n", | |
"D = 100\n", | |
"\n", | |
"# ground-truth histogram, drawn from a symmetric Dirichlet prior whose\n", | |
"# concentration parameters sum to a\n", | |
"a = 10.\n", | |
"p = dirichlet(D * [a / D])" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [], | |
"prompt_number": 153 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"def cross_validate(X):\n", | |
"    \"\"\"\n", | |
"    Leave-one-out cross-validation of the histogram estimator on data X.\n", | |
"\n", | |
"    Each row of X is held out in turn; a histogram is estimated from the\n", | |
"    remaining rows and the held-out row's loss is recorded. Returns the mean\n", | |
"    of these losses and their sample standard deviation (ddof=1).\n", | |
"    \"\"\"\n", | |
"    # iterate over the rows of X itself rather than the global N, so the\n", | |
"    # function is correct for datasets of any size\n", | |
"    loss = []\n", | |
"    for n in range(X.shape[0]):\n", | |
"        q = estimate(delete(X, n, 0))\n", | |
"        loss.append(evaluate(X[n], q))\n", | |
"\n", | |
"    return mean(loss), std(loss, ddof=1)" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [], | |
"prompt_number": 150 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"# number of independent datasets to run cross-validation on\n", | |
"R = 10000\n", | |
"\n", | |
"# per-dataset cross-validation mean loss and standard deviation of losses\n", | |
"loss_avg = []\n", | |
"loss_std = []\n", | |
"\n", | |
"for rep in range(R):\n", | |
"    cv_mean, cv_std = cross_validate(sample(p, N))\n", | |
"    loss_avg.append(cv_mean)\n", | |
"    loss_std.append(cv_std)\n", | |
"\n", | |
"# cross-validation estimates averaged over all datasets\n", | |
"# (print() with one argument behaves identically on Python 2 and 3)\n", | |
"print(mean(loss_avg))\n", | |
"print(mean(loss_std))" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"3.56419933227\n", | |
"0.713336314565\n" | |
] | |
} | |
], | |
"prompt_number": 154 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"# number of Monte Carlo training sets\n", | |
"R = 10000\n", | |
"\n", | |
"# fit a histogram to R independent training sets of size N - 1 (the training\n", | |
"# size used by leave-one-out cross-validation) and keep the log of each fit\n", | |
"logq = []\n", | |
"for rep in range(R):\n", | |
"    logq.append(log(estimate(sample(p, N - 1))))\n", | |
"\n", | |
"# expected loss E[-log q(x)] and its standard deviation, where x ~ p is a\n", | |
"# fresh test point and q is fitted to an independent training set\n", | |
"loss_avg = -sum(p * mean(logq, 0))\n", | |
"loss_std = sqrt(sum(p * mean(square(logq), 0)) - loss_avg**2)\n", | |
"\n", | |
"# reference values to compare with the cross-validation estimates above\n", | |
"print(loss_avg)\n", | |
"print(loss_std)" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"3.5636444913\n", | |
"0.718542014501\n" | |
] | |
} | |
], | |
"prompt_number": 155 | |
} | |
], | |
"metadata": {} | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment