Skip to content

Instantly share code, notes, and snippets.

@jonathan-taylor
Last active December 16, 2015 06:00
Show Gist options
  • Save jonathan-taylor/04cc0306a08fb15c4c30 to your computer and use it in GitHub Desktop.
Save jonathan-taylor/04cc0306a08fb15c4c30 to your computer and use it in GitHub Desktop.
Selective inference for logistic regression
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/jonathantaylor/anaconda/lib/python2.7/site-packages/rpy2/robjects/functions.py:106: UserWarning: Loading required package: Matrix\n",
"\n",
" res = super(Function, self).__call__(*new_args, **new_kwargs)\n",
"/Users/jonathantaylor/anaconda/lib/python2.7/site-packages/rpy2/robjects/functions.py:106: UserWarning: Loading required package: foreach\n",
"\n",
" res = super(Function, self).__call__(*new_args, **new_kwargs)\n",
"/Users/jonathantaylor/anaconda/lib/python2.7/site-packages/rpy2/robjects/functions.py:106: UserWarning: foreach: simple, scalable parallel programming from Revolution Analytics\n",
"Use Revolution R for scalability, fault tolerance and more.\n",
"http://www.revolutionanalytics.com\n",
"\n",
" res = super(Function, self).__call__(*new_args, **new_kwargs)\n",
"/Users/jonathantaylor/anaconda/lib/python2.7/site-packages/rpy2/robjects/functions.py:106: UserWarning: Loaded glmnet 2.0-2\n",
"\n",
"\n",
" res = super(Function, self).__call__(*new_args, **new_kwargs)\n",
"/Users/jonathantaylor/anaconda/lib/python2.7/site-packages/selection/algorithms/lasso.py:32: UserWarning: cvx not available\n",
" warnings.warn('cvx not available')\n"
]
}
],
"source": [
"import numpy as np\n",
"%load_ext rpy2.ipython\n",
"%R library(glmnet)\n",
"from selection.algorithms.logistic import instance\n",
"import regreg.api as rr"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"(100, 200)"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X, Y, beta, active = instance()\n",
"n, p = X.shape\n",
"n, p"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Analog to active KKT conditions\n",
"\n",
"Logistic regression with LASSO penalty is\n",
"$$\n",
"\\text{minimize}_{\\beta} \\frac{1}{n} \\ell(\\beta) + \\lambda \\|\\beta\\|_1\n",
"$$\n",
"where $\\ell$ is negative-log of the logistic likelihood.\n",
"\n",
"It should not be too hard to convince ourselves that the\n",
"analog of the active block of the KKT conditions for logistic regression with LASSO penalty is\n",
"$$\n",
"\\text{sign}\\left(\\bar{\\beta}_E - \\lambda Q_E(\\bar{\\beta}_E)^{-1}z_E\\right) = z_E\n",
"$$\n",
"where $\\bar{\\beta}_E$ are the unpenalized logistic estimators for that active set and\n",
"$$\n",
"\\begin{aligned}\n",
"Q_E(\\beta_E) &= \\frac{1}{n} X_E^TW_E(\\beta_E;X)X_E \\\\\n",
"W_E(\\beta_E) &= \\text{diag}(\\pi_E(\\beta_E;X)(1 - \\pi_E(\\beta_E;X))) \\\\\n",
"\\pi_E(\\beta_E;X) &= \\frac{\\exp(X_E\\beta_E)}{1 + \\exp(X_E\\beta_E)}\n",
"\\end{aligned}\n",
"$$\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"0"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def simulate(n=100, p=50, rho=0.3, s=5, snr=5, lam_frac=0.8, \n",
" choice='theoretical',\n",
" use_glmnet=True, \n",
" approx='unpenalized'):\n",
" X, Y, beta, true_active = instance(n=n, p=p, rho=rho, s=s, snr=snr)\n",
" \n",
" n, p = X.shape\n",
" \n",
" # add intercept\n",
" X1 = np.hstack([np.ones((n,1)), X])\n",
" X -= X.mean(0)[None,:]\n",
"\n",
" if choice != 'theoretical':\n",
" lam = lam_frac * np.fabs(X.T.dot(Y)).max() / n\n",
" else:\n",
" Z = np.random.binomial(1, 0.5, size=(n, 5000))\n",
" lam = np.fabs(X.T.dot(Z)).max(0) / n\n",
" lam = lam.mean()\n",
" \n",
" # solve the penalized problem\n",
" \n",
" %R -i X,Y,lam\n",
" %R Y = as.numeric(Y)\n",
" %R G = glmnet(X, Y, family='binomial', standardize=FALSE)\n",
" %R soln = as.numeric(coef(G, lam, exact=TRUE))\n",
" soln_glmnet = %R soln\n",
"\n",
" if use_glmnet:\n",
" soln = soln_glmnet\n",
" else: # use regreg\n",
" loss = rr.logistic_loss(X1, Y.copy(), coef=0.5)\n",
" weights = np.ones(p+1)\n",
" weights[0] = 0.\n",
" penalty = rr.weighted_l1norm(weights, lagrange=lam)\n",
" problem = rr.simple_problem(loss, penalty)\n",
" soln = problem.solve(min_its=200, tol=1.e-16)\n",
" \n",
" active = (soln != 0)\n",
" active_signs = np.sign(soln[active])\n",
" active_signs[0] = 0.\n",
" \n",
" # solve the unpenalized problem restricted \n",
" # to the active set\n",
" \n",
" X_E = X1[:,active]\n",
" restricted_loss = rr.logistic_loss(X_E, Y)\n",
" unpenalized = restricted_loss.solve(min_its=200) # \\bar{\\beta}_E\n",
" \n",
" if approx == 'unpenalized':\n",
" lin_pred = X_E.dot(unpenalized)\n",
" elif approx == 'truth':\n",
" lin_pred = X.dot(beta)\n",
" else:\n",
" lin_pred = X_E.dot(soln[active])\n",
" \n",
" pi_E = np.exp(lin_pred) / (1 + np.exp(lin_pred))\n",
" W_E = pi_E * (1 - pi_E)\n",
" Q_E = np.dot(X_E.T, W_E[:, None] * X_E)\n",
" \n",
" # check that the penalized coefficients satisfy active \n",
" # KKT conditions\n",
" affine_active = active_signs[1:] * (unpenalized - lam * np.linalg.solve(Q_E, active_signs))[1:]\n",
" \n",
" # how many things violate the KKT conditions?\n",
" \n",
" return (affine_active < 0).sum()\n",
"\n",
"simulate(n=100, use_glmnet=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using glmnet to solve the problem"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total number of failures out of 200: 4\n"
]
}
],
"source": [
"np.random.seed(0)\n",
"ntrial = 200\n",
"failed = 0\n",
"\n",
"for i in range(ntrial):\n",
" failed += simulate(n=100, use_glmnet=True) > 0\n",
" \n",
"print 'Total number of failures out of %d: %d' % (ntrial, failed)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using regreg"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total number of failures out of 200: 4\n"
]
}
],
"source": [
"np.random.seed(0)\n",
"ntrial = 200\n",
"failed = 0\n",
"\n",
"for _ in range(ntrial):\n",
" failed += simulate(use_glmnet=False) > 0\n",
" \n",
"print 'Total number of failures out of %d: %d' % (ntrial, failed)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## More features\n",
"\n",
"Let's up the features to $p=200$."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total number of failures out of 200: 28\n"
]
}
],
"source": [
"np.random.seed(0)\n",
"ntrial = 200\n",
"failed = 0\n",
"\n",
"for _ in range(ntrial):\n",
" failed += simulate(p=200, use_glmnet=True) > 0\n",
" \n",
"print 'Total number of failures out of %d: %d' % (ntrial, failed)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using true parameters for quadratic\n",
"\n",
"Let's expand around $\\beta_E^*$"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total number of failures out of 200: 4\n"
]
}
],
"source": [
"np.random.seed(0)\n",
"ntrial = 200\n",
"failed = 0\n",
"\n",
"for _ in range(ntrial):\n",
" failed += simulate(p=50, use_glmnet=True, approx='truth') > 0\n",
" \n",
"print 'Total number of failures out of %d: %d' % (ntrial, failed)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### More features"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total number of failures out of 200: 29\n"
]
}
],
"source": [
"np.random.seed(0)\n",
"ntrial = 200\n",
"failed = 0\n",
"\n",
"for _ in range(ntrial):\n",
" failed += simulate(p=200, use_glmnet=True, approx='truth') > 0\n",
" \n",
"print 'Total number of failures out of %d: %d' % (ntrial, failed)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Yet another quadratic\n",
"\n",
"Let's try expanding around $\\hat{\\beta}_E$."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total number of failures out of 200: 29\n"
]
}
],
"source": [
"np.random.seed(0)\n",
"ntrial = 200\n",
"failed = 0\n",
"\n",
"for _ in range(ntrial):\n",
" failed += simulate(p=200, use_glmnet=True, approx=None) > 0\n",
"\n",
"print 'Total number of failures out of %d: %d' % (ntrial, failed)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Increasing sample size improves the problem\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total number of failures out of 200: 0\n"
]
}
],
"source": [
"np.random.seed(0)\n",
"ntrial = 200\n",
"failed = 0\n",
"\n",
"for i in range(ntrial):\n",
" failed += simulate(n=500, p=50, use_glmnet=True) > 0\n",
"\n",
"print 'Total number of failures out of %d: %d' % (ntrial, failed)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total number of failures out of 200: 3\n"
]
}
],
"source": [
"np.random.seed(0)\n",
"ntrial = 200\n",
"failed = 0\n",
"\n",
"for i in range(ntrial):\n",
" failed += simulate(n=500, p=200, use_glmnet=True) > 0\n",
" \n",
"print 'Total number of failures out of %d: %d' % (ntrial, failed)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total number of failures out of 200: 0\n"
]
}
],
"source": [
"np.random.seed(0)\n",
"ntrial = 200\n",
"failed = 0\n",
"\n",
"for i in range(ntrial):\n",
" failed += simulate(n=1000, p=200, use_glmnet=True) > 0\n",
" \n",
"print 'Total number of failures out of %d: %d' % (ntrial, failed)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Pure R code\n",
"\n",
"In principle, I think the code below should work but I keep getting a few errors. Maybe you can fix it?\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"%%R\n",
"library(glmnet)\n",
"\n",
"simulate = function(n=100,\n",
" p=80,\n",
" s=5,\n",
" snr=5,\n",
" rho=0.3,\n",
" lam_frac=0.8,\n",
" X=NULL,\n",
" Z=NULL,\n",
" lam=NULL) {\n",
"\n",
" # create some data\n",
" # where the logistic model is\n",
" # actually true\n",
" \n",
" if (is.null(X)) {\n",
" X = sqrt(1-rho) * matrix(rnorm(n*p), n, p) + sqrt(rho) * outer(rnorm(n), rep(1, p))\n",
" X = scale(X)\n",
" X = X / sqrt(n)\n",
" }\n",
" \n",
" if (is.null(Z)) {\n",
" beta = matrix(0, p)\n",
" beta[1:s] = snr\n",
"\n",
" eta = X %*% beta\n",
" pi = exp(eta) / (1 + exp(eta))\n",
"\n",
" Z = rbinom(n, 1, pi)\n",
" }\n",
" \n",
" # choose a fixed value of \\lambda\n",
" # if not provided\n",
"\n",
" if (is.null(lam)) {\n",
" Z0 = matrix(rbinom(n*5000, 1, 0.5), n, 5000)\n",
" lam = as.numeric(lam_frac * mean(apply(abs(t(X) %*% Z0), 2, max)) / n)\n",
" }\n",
" \n",
" G = glmnet(X,\n",
" Z,\n",
" family='binomial',\n",
" #intercept=FALSE,\n",
" standardize=FALSE)\n",
"\n",
" # fails here for some reason\n",
"\n",
" print('makes it to here')\n",
" soln = coef(G, s=lam, exact=TRUE)\n",
" print('but fails above')\n",
" soln = soln[2:length(soln)] # get rid of Intercept column\n",
" active = which(soln != 0)\n",
"\n",
" X_E = X[,active]\n",
" \n",
" unpenalized = glm.fit(X_E, Z, family=binomial('logit'))$coef \n",
" linpred = X_E %*% unpenalized\n",
" \n",
" pi_E = exp(linpred) / (1 + exp(linpred))\n",
" W_E = as.numeric(pi_E * (1 - pi_E))\n",
" Q_E = t(X_E) %*% (diag(W_E) %*% X_E)\n",
"\n",
" active_signs = sign(soln)[active]\n",
" affine_active = active_signs * (unpenalized - lam * solve(Q_E, active_signs))\n",
"\n",
" return(sum(affine_active < 0))\n",
" \n",
"}\n",
"\n",
"simulate()\n",
"\n",
"failures = 0\n",
"\n",
"for (i in 1:20) {\n",
" failures = failures + (simulate() > 0)\n",
"}\n",
"print(failures)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.10"
},
"latex_envs": {
"bibliofile": "biblio.bib",
"cite_by": "apalike",
"current_citInitial": 1,
"eqLabelWithNumbers": true,
"eqNumInitial": 0
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment