@poldrack
Created December 6, 2012 02:28
Crossvalidation for regression
{
"metadata": {
"name": "regressioncv"
},
"nbformat": 2,
"worksheets": [
{
"cells": [
{
"cell_type": "markdown",
"source": [
"This notebook lays out some simulations that examine the effects of various crossvalidation schemes in the context",
"of regression models. The code for these simulations is available at https://github.com/poldrack/regressioncv - in ",
"particular, this example is based on regressioncv_2d.py",
"",
"First, we import the necessary functions:"
]
},
{
"cell_type": "code",
"collapsed": true,
"input": [
"import numpy as N",
"from sklearn.linear_model import LinearRegression",
"from sklearn import cross_validation",
"from statsmodels.regression.linear_model import OLS"
],
"language": "python",
"outputs": [],
"prompt_number": 7
},
{
"cell_type": "markdown",
"source": [
"Set up some variables that we will need."
]
},
{
"cell_type": "code",
"collapsed": true,
"input": [
"nsubs=48",
"nruns=100"
],
"language": "python",
"outputs": [],
"prompt_number": 8
},
{
"cell_type": "markdown",
"source": [
"Define a function to perform balanced crossvalidation - this function performs an ANOVA on the X and Y",
"variables across randomly drawn crossvalidation folds until it finds a split whose p-values both exceed the",
"specified threshold. In practice a high threshold should yield folds with relatively similar distributions,",
"though it might make sense to try some nonparametric techniques",
"in order to do a better job of also matching the range of the X values."
]
},
{
"cell_type": "code",
"collapsed": true,
"input": [
"def get_sample_balcv(x,y,nfolds,pthresh=0.9):",
"    \"\"\"",
"    This function uses anova across CV folds to find",
"    a set of folds that are balanced in their distributions",
"    of the X and Y values - see Kohavi, 1995",
"    \"\"\"",
"",
"    nsubs=len(x)",
"",
"    # cycle through until we find a split that is good enough",
"",
"    good_split=0",
"    while good_split==0:",
"        cv=cross_validation.KFold(n=nsubs,n_folds=nfolds,shuffle=True)",
"        ctr=0",
"        idx=N.zeros((nsubs,nfolds)) # this is the design matrix",
"        for train,test in cv:",
"            idx[test,ctr]=1",
"            ctr+=1",
"",
"        lm_x=OLS(x-N.mean(x),idx).fit()",
"        lm_y=OLS(y-N.mean(y),idx).fit()",
"",
"        if lm_x.f_pvalue>pthresh and lm_y.f_pvalue>pthresh :",
"            good_split=1",
"",
"    # do some reshaping needed for the sklearn linear regression function",
"    x=x.reshape((nsubs,1))",
"    y=y.reshape((nsubs,1))",
"",
"    pred=N.zeros((nsubs,1))",
"",
"    for train,test in cv:",
"        lr=LinearRegression()",
"        lr.fit(x[train,:],y[train,:])",
"        pred[test]=lr.predict(x[test])",
"",
"    return N.corrcoef(pred[:,0],y[:,0])[0,1]"
],
"language": "python",
"outputs": [],
"prompt_number": 9
},
{
"cell_type": "markdown",
"source": [
"Set up a variable to store all the results"
]
},
{
"cell_type": "code",
"collapsed": true,
"input": [
"corrs={'splithalf':N.zeros(nruns),'loo':N.zeros(nruns),'balcv_lo':N.zeros(nruns),'balcv_hi':N.zeros(nruns)}"
],
"language": "python",
"outputs": [],
"prompt_number": 10
},
{
"cell_type": "markdown",
"source": [
"This is the meat of the simulation - we loop through and perform each crossvalidation type on each simulated dataset."
]
},
{
"cell_type": "code",
"collapsed": true,
"input": [
"for run in range(nruns):",
"    # create the X distribution - this will be doubled over",
"    x=N.random.rand(nsubs//2).reshape((nsubs//2,1))",
"",
"    # create two completely random Y samples",
"    y=N.random.rand(len(x)).reshape(x.shape)",
"    y2=N.random.rand(len(x)).reshape(x.shape)",
"",
"    pred_y=N.zeros(x.shape)",
"    pred_y2=N.zeros(x.shape)",
"",
"    # compute out-of-sample predictions for each Y dataset",
"    # i.e., balanced split-half crossvalidation",
"    lr=LinearRegression()",
"    lr.fit(x,y)",
"    pred_y2=lr.predict(x)",
"    lr.fit(x,y2)",
"    pred_y=lr.predict(x)",
"",
"    pred_splithalf=N.vstack((pred_y,pred_y2))",
"    y_all=N.vstack((y,y2))",
"    corrs['splithalf'][run]=N.corrcoef(pred_splithalf[:,0],y_all[:,0])[0,1]",
"",
"",
"## # this is an alternate way to implement split-half, removed for now",
"## # after confirming that it gives identical results to the one above",
"## split=cross_validation.KFold(n=len(x_all),n_folds=2,shuffle=False)",
"## pred_splitcv=N.zeros(x_all.shape)",
"## for train,test in split:",
"## lr=LinearRegression()",
"## lr.fit(x_all[train,:],y_all[train,:])",
"## pred_splitcv[test]=lr.predict(x_all[test])",
"## corrs['splitcv'][run]=N.corrcoef(pred_splitcv[:,0],y_all[:,0])[0,1]",
"",
"    # this stacked version is needed for loo",
"",
"    x_all=N.vstack((x,x))",
"",
"    # do leave-one-out CV using sklearn tools",
"    loo=cross_validation.LeaveOneOut(nsubs)",
"    pred_loo=N.zeros(x_all.shape)",
"    for train,test in loo:",
"        lr=LinearRegression()",
"        lr.fit(x_all[train,:],y_all[train,:])",
"        pred_loo[test]=lr.predict(x_all[test])",
"    corrs['loo'][run]=N.corrcoef(pred_loo[:,0],y_all[:,0])[0,1]",
"",
"    # get results with balanced CV for high and low threshold",
"    corrs['balcv_lo'][run]=get_sample_balcv(x_all,y_all,8,pthresh=0.001)",
"    corrs['balcv_hi'][run]=get_sample_balcv(x_all,y_all,8,pthresh=0.99)"
],
"language": "python",
"outputs": [],
"prompt_number": 11
},
{
"cell_type": "markdown",
"source": [
"Now print out the correlation between predicted and true values for each of the crossvalidation types. Remember that the X ",
"and Y values were generated purely at random, so in principle there should be no correlation between predicted and actual ",
"values."
]
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"print 'Mean correlation (predicted,true):'",
"for k in corrs.iterkeys():",
" print k,N.mean(corrs[k])"
],
"language": "python",
"outputs": [
{
"output_type": "stream",
"stream": "stdout",
"text": [
"Mean correlation (predicted,true):",
"splithalf -0.0834828473165",
"balcv_hi -0.167181054783",
"loo -0.370596457262",
"balcv_lo -0.241396661229"
]
}
],
"prompt_number": 12
},
{
"cell_type": "markdown",
"source": [
"What you should see is that all of the crossvalidation types have negative correlations between predicted",
"and actual values. This negative bias should be weakest for split-half CV, strongest for leave-one-out, and intermediate",
"for the balanced CV (balcv_hi and balcv_lo use high and low thresholds for the anova, respectively)."
]
}
]
}
]
}
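
The notebook above relies on a 2012-era scikit-learn API and Python 2 syntax. For readers who only want to see the leave-one-out effect, the following is a minimal sketch using the Python 3 standard library alone. It is not the notebook's own code: it assumes one independent random X sample per run (rather than the doubled-over X used above), fits a single-predictor least-squares line by hand, and the helper names `fit_line`, `pearson`, `loo_corr` and `mean_loo_corr` are hypothetical.

```python
import random


def fit_line(xs, ys):
    """Ordinary least-squares slope and intercept for a single predictor."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    sxx = sum((x - mx) ** 2 for x in xs)
    sxy = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    slope = sxy / sxx
    return slope, my - slope * mx


def pearson(a, b):
    """Pearson correlation of two equal-length sequences."""
    n = len(a)
    ma, mb = sum(a) / n, sum(b) / n
    cov = sum((p - ma) * (q - mb) for p, q in zip(a, b))
    var_a = sum((p - ma) ** 2 for p in a)
    var_b = sum((q - mb) ** 2 for q in b)
    return cov / (var_a * var_b) ** 0.5


def loo_corr(xs, ys):
    """Correlation between leave-one-out predictions and the true values."""
    preds = []
    for i in range(len(xs)):
        # fit on every observation except i, then predict observation i
        slope, intercept = fit_line(xs[:i] + xs[i + 1:], ys[:i] + ys[i + 1:])
        preds.append(intercept + slope * xs[i])
    return pearson(preds, ys)


def mean_loo_corr(nsubs=48, nruns=100, seed=0):
    """Average the leave-one-out correlation over runs of pure-noise data."""
    rng = random.Random(seed)  # seeded so the sketch is repeatable
    total = 0.0
    for _ in range(nruns):
        xs = [rng.random() for _ in range(nsubs)]  # random predictor
        ys = [rng.random() for _ in range(nsubs)]  # outcome unrelated to xs
        total += loo_corr(xs, ys)
    return total / nruns


print("Mean LOO correlation (predicted, true):", mean_loo_corr())
```

Under these assumptions the printed mean is expected to be negative even though X and Y are unrelated. The exact value depends on the seed and the simulation design, so it will not match the figures recorded in the notebook.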