Created September 18, 2014 21:58
alexstorer/177a3bfcb3c7857d862d
Regression in Python
"metadata": {
"name": "",
"signature": "sha256:24223f6595c3dfb9697e041ee5edebdf964c96a58dba27ef78e7764643f9a675"
"nbformat": 3,
"nbformat_minor": 0,
"worksheets": [
"cells": [
"cell_type": "markdown",
"metadata": {},
"source": [
"### Demonstration of multiple regression in Python\n",
"The file `fn.txt` lives online here:\n",
"The first thing we must do is load the numpy library. We call it `np` for convenience. Any time I don't know what to do with numpy (which is pretty much all the time), I just google it. To load a text file, for example, Google returned this link:\n",
"From the numpy documentation. Handy!\n",
"We can also use this guide to convert from Matlab to numpy, but I haven't always found it to be perfect:\n",
"The numpy folks also have a guide to give you a heads up if you're making the switch:\n",
"cell_type": "code",
"collapsed": false,
"input": [
"import numpy as np"
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 1
"cell_type": "code",
"collapsed": false,
"input": [
"f = np.loadtxt(\"/tmp/fn.txt\")"
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 104
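The file `fn.txt` itself isn't reproduced here, so the sketch below feeds `np.loadtxt` a few rows from an in-memory buffer instead. It assumes the file has the same layout as the array displayed in this notebook (three whitespace-separated columns per line); the rows are copied from that output. `np.loadtxt` accepts any file-like object as well as a path, which is what makes the stand-in possible.

```python
import io

import numpy as np

# Stand-in for /tmp/fn.txt (assumed layout: three whitespace-separated
# columns per row; values copied from the first rows shown in this notebook).
text = io.StringIO(
    "-0.57490156 -1.91351072 -3.39301871\n"
    " 0.65598603  0.22072469  1.47385001\n"
    " 0.43658005 -0.47635545 -2.60489563\n"
)

f = np.loadtxt(text)  # default delimiter: any run of whitespace
print(f.shape)        # (3, 3): one row per line, one column per field
```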
"cell_type": "code",
"collapsed": false,
"input": [
"f"
"language": "python",
"metadata": {},
"outputs": [
"metadata": {},
"output_type": "pyout",
"prompt_number": 105,
"text": [
"array([[-0.57490156, -1.91351072, -3.39301871],\n",
" [ 0.65598603, 0.22072469, 1.47385001],\n",
" [ 0.43658005, -0.47635545, -2.60489563],\n",
" [ 0.09937458, -0.39675894, 4.29225966],\n",
" [-1.47921372, -0.52771854, -6.01904147],\n",
" [ 0.02433745, 0.45487558, 0.61437382],\n",
" [ 0.923521 , 1.03720213, -0.1912883 ],\n",
" [ 1.87006443, 1.47769287, 6.15090212],\n",
" [-2.09596711, -1.24400998, -1.70133112],\n",
" [-0.44377387, 0.06032462, 0.39509815],\n",
" [ 0.01500656, -0.23688102, 2.2656384 ],\n",
" [ 1.77983053, 4.26975795, 4.97467405],\n",
" [ 0.25642042, -2.3156647 , 1.76530788],\n",
" [ 1.35945287, -0.67858213, 3.75902308],\n",
" [-0.59986296, -0.47666583, -1.2284126 ],\n",
" [-0.66103391, -1.80281943, -2.49601437],\n",
" [ 1.90396449, 0.75357797, 2.99043148],\n",
" [-0.16151286, 0.50255836, -0.81288138],\n",
" [ 0.4664114 , 0.84191455, 1.44697245],\n",
" [-0.43481598, -0.63350808, -1.79533482],\n",
" [-0.83276604, -0.80947849, 0.0683582 ],\n",
" [ 1.0480383 , 1.45046956, 0.25768814],\n",
" [-0.48164113, -0.38361531, -0.17502925],\n",
" [-0.60636319, -1.91081317, -1.47488151],\n",
" [ 0.90443372, 1.84000747, 3.88174901],\n",
" [-0.50978463, 0.35992185, -2.5341687 ],\n",
" [-0.61345175, 0.73248085, 1.76839231],\n",
" [-0.61422979, -3.12200409, -0.13936483],\n",
" [-1.41841266, -0.236586 , -2.87044634],\n",
" [ 1.01300264, -0.88717944, 2.25242043],\n",
" [ 0.24827838, -1.34982394, -1.17282819],\n",
" [ 0.02880206, 0.12824574, 1.33568278],\n",
" [-0.63395938, 0.7085945 , -2.86663276],\n",
" [ 0.73879757, 0.90078908, 3.87506212],\n",
" [-1.69888806, -2.51282168, -1.83977286],\n",
" [ 0.5857642 , 0.78378598, 4.72343417],\n",
" [ 0.14539581, 1.00446406, -4.39537347],\n",
" [-0.3089474 , -1.94391481, 0.2244121 ],\n",
" [ 0.08155905, 1.45149204, -0.60563448],\n",
" [ 0.70085849, -0.88537122, 2.00585174],\n",
" [ 0.45432095, 0.12417837, 0.47614208],\n",
" [-0.49809093, -1.0244259 , -2.23859181],\n",
" [ 0.42854436, 0.63110684, -1.65305689],\n",
" [ 0.20776615, 1.70997573, -1.72362789],\n",
" [-0.60337805, -0.12686956, -2.44121784],\n",
" [ 0.87704736, 1.42309737, 3.10503514],\n",
" [ 0.17901089, 0.95111581, 2.07639851],\n",
" [ 0.11356608, -0.51775454, -1.13834377],\n",
" [ 0.89017288, 0.15716723, -0.50629286],\n",
" [-0.37527344, -0.29820463, -3.16137262],\n",
" [-0.12692828, -1.59778286, -5.49591611],\n",
" [ 0.05643522, 0.22471085, -0.66622035],\n",
" [ 0.69279882, 1.58794651, -1.73643055],\n",
" [-1.21437221, 0.92309239, -2.69699981],\n",
" [-0.89742822, -3.20130424, -3.05975339],\n",
" [ 0.73518875, -0.01617594, 3.82356332],\n",
" [-0.5947668 , -1.90209894, -2.42781527],\n",
" [-0.28923092, -1.06160602, -0.0322157 ],\n",
" [-0.21630749, 0.83317469, -1.20135793],\n",
" [-0.18545978, 1.70683545, -2.1579663 ],\n",
" [-1.26635027, 1.54552983, -2.57887681],\n",
" [-0.37633385, 1.89658346, -0.09618742],\n",
" [-1.71843821, -1.2162742 , -3.96504179],\n",
" [ 0.10448439, 0.89011683, 1.18702495],\n",
" [-0.07698252, 0.84281102, 0.49119564],\n",
" [ 0.10519728, -0.46324068, 1.03383783],\n",
" [ 1.09575593, 1.18847953, 3.89683228],\n",
" [-0.18691149, -1.40007452, -1.40050641],\n",
" [-1.25889559, -0.16652752, -0.92006681],\n",
" [-1.1039712 , -0.59517726, 2.11959324],\n",
" [ 1.60616567, 0.13908959, 1.57408198],\n",
" [ 0.94393435, 1.10878304, 4.78541271],\n",
" [-0.15829364, -1.20972407, 0.0301712 ],\n",
" [-0.88485128, -0.3003857 , -3.88537044],\n",
" [-0.19779125, 1.64214349, 0.29952011],\n",
" [-0.55194435, -2.38045291, 0.49360922],\n",
" [ 1.31166668, 0.38692144, 0.48485089],\n",
" [-1.18984975, -1.97702907, -1.76762911],\n",
" [ 0.05158668, 0.57432922, -2.09170082],\n",
" [-0.29061381, 0.19771936, -2.96849905],\n",
" [ 0.05973922, -2.89138948, -2.1855287 ],\n",
" [ 0.0650401 , -0.17037504, -0.85841545],\n",
" [-0.36338606, 0.37380396, -3.79889257],\n",
" [-1.28522827, 0.48634568, -2.34494568],\n",
" [-0.06859874, -0.59902582, 1.75510453],\n",
" [-1.34416803, -0.25045908, 0.05798853],\n",
" [-0.66504755, -0.08982944, 1.77270165],\n",
" [ 0.33986993, 0.21612558, -2.11661965],\n",
" [ 1.67843577, 0.55810764, 3.12369016],\n",
" [-0.83469781, 0.06949099, -2.33815019],\n",
" [ 0.38124986, -0.4272168 , -1.80814763],\n",
" [ 0.23099005, 1.08864373, 4.17000617],\n",
" [-0.22241665, 0.16425892, -1.88686209],\n",
" [ 0.42553115, 2.46916041, 3.91101421],\n",
" [ 1.13058716, 0.83237927, -0.68621658],\n",
" [-0.23167494, 0.78469492, 3.78702313],\n",
" [ 0.96832249, 0.58709605, 3.62006603],\n",
" [-2.56461595, -1.64636548, -8.35215179],\n",
" [ 0.9270586 , -0.7307362 , 3.5471416 ],\n",
" [-1.44429727, -1.07769535, -5.73425796]])"
"prompt_number": 105
"cell_type": "markdown",
"metadata": {},
"source": [
"Your code in Matlab was as follows:\n",
"y=fn(:,1); % select outcome to be the first variable\n",
"x=fn(:,2:3); % select regressors to be the last two variables\n",
"n=length(y); % number of observations\n",
"x=[ones(n,1),x]; % add constant term to regressors\n",
"beta=inv(x'*x)*(x'*y) % calculate least squares regression coefficients\n",
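"We can use the default `array` type in numpy to do these operations. Keep in mind that numpy indexes from 0, so Matlab's `fn(:,1)` becomes `f[:,0]` and `fn(:,2:3)` becomes `f[:,1:3]` (the end of a Python slice is exclusive)."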
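The Matlab operators map onto numpy roughly as `x'` to `x.T`, `*` (matrix product) to `@` (Python 3.5+), `inv` to `np.linalg.inv`, and `ones(n,1)` to `np.ones((n,1))`. The sketch below checks that translation on synthetic data with known coefficients; the seed, sample size, and true coefficients are made up for illustration and are not the `fn.txt` data. It also solves the same normal equations with `np.linalg.solve`, which avoids forming the inverse explicitly and is usually the numerically safer choice.

```python
import numpy as np

# Hypothetical synthetic data with known coefficients.
rng = np.random.default_rng(0)
n = 200
x_raw = rng.normal(size=(n, 2))           # two regressors
true_beta = np.array([1.0, 2.0, -0.5])    # intercept, slope 1, slope 2
y = true_beta[0] + x_raw @ true_beta[1:] + rng.normal(scale=0.1, size=n)

# Matlab: x=[ones(n,1),x];
x = np.hstack((np.ones((n, 1)), x_raw))

# Matlab: beta=inv(x'*x)*(x'*y)
beta_inv = np.linalg.inv(x.T @ x) @ (x.T @ y)

# Same normal equations, solved without an explicit inverse
beta = np.linalg.solve(x.T @ x, x.T @ y)

print(beta)
```

Both routes give the same `beta` up to rounding, and with this little noise it lands close to `true_beta`.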
"cell_type": "code",
"collapsed": false,
"input": [
"y = f[:,0]\n",
"x = f[:,1:3]\n",
"n = y.shape[0]\n",
"x = np.hstack((np.ones((n,1)),x))\n",
"beta = np.dot(np.linalg.inv(np.dot(x.T,x)),np.dot(x.T,y))"
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 128
"cell_type": "code",
"collapsed": false,
"input": [
"beta"
"language": "python",
"metadata": {},
"outputs": [
"metadata": {},
"output_type": "pyout",
"prompt_number": 129,
"text": [
"array([-0.01597851, 0.161309 , 0.18426684])"
"prompt_number": 129
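Forming `x.T @ x` squares the condition number of `x`, so when the regressors are nearly collinear a dedicated least-squares solver is a safer route to the same `beta`. A minimal sketch with `np.linalg.lstsq` follows; the random array `f` here is a hypothetical stand-in for the notebook's data, used only to show that both approaches agree.

```python
import numpy as np

# Hypothetical stand-in for the loaded data (100 rows, 3 columns).
rng = np.random.default_rng(1)
n = 100
f = rng.normal(size=(n, 3))

y = f[:, 0]                                   # outcome: first column
x = np.hstack((np.ones((n, 1)), f[:, 1:3]))   # constant + last two columns

# Normal-equations answer, as in the cells above
beta_normal = np.linalg.solve(x.T @ x, x.T @ y)

# Least-squares solver; rcond=None selects numpy's current default cutoff
beta_lstsq, residuals, rank, sv = np.linalg.lstsq(x, y, rcond=None)

print(np.allclose(beta_normal, beta_lstsq))   # True
```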
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also use the `matrix` type in numpy. It makes matrix multiplication easier (for `matrix` objects, `*` means matrix multiplication rather than elementwise multiplication), but everything else is harder.\n",
"Note that when converting `y` to a matrix, we must use the transpose (`.T`) to force the dimensions to align.\n",
"The `.I` property contains the inverse of the matrix."
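NumPy's own documentation now recommends against `np.matrix`, even for linear algebra, and notes that the class may be removed in the future. The sketch below, on a tiny made-up array, shows why the `.T` is needed in the matrix version (a 1-D input becomes a row) but does nothing to a plain 1-D array, how to get an explicit column vector with ordinary arrays, and that `.I` corresponds to `np.linalg.inv`.

```python
import numpy as np

y = np.arange(4.0)                 # made-up 1-D array, shape (4,)
print(np.matrix(y).shape)          # (1, 4): a 1-D input becomes a row
print(np.matrix(y).T.shape)        # (4, 1): hence the .T in the matrix version
print(y.T.shape)                   # (4,): .T leaves a 1-D array unchanged

y_col = y.reshape(-1, 1)           # explicit column vector with a plain array
print(y_col.shape)                 # (4, 1)

a = np.array([[2.0, 0.0], [0.0, 4.0]])
print(np.allclose(np.matrix(a).I, np.linalg.inv(a)))  # True
```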
"cell_type": "code",
"collapsed": false,
"input": [
"y = f[:,0]\n",
"x = f[:,1:3]\n",
"n = y.shape[0]\n",
"x = np.hstack((np.ones((n,1)),x))\n",
"y = np.matrix(y).T\n",
"x = np.matrix(x)\n",
"beta = (x.T*x).I * (x.T*y)"
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 124
"cell_type": "code",
"collapsed": false,
"input": [
"beta"
"language": "python",
"metadata": {},
"outputs": [
"metadata": {},
"output_type": "pyout",
"prompt_number": 125,
"text": [
"matrix([[-0.01597851],\n",
" [ 0.161309 ],\n",
" [ 0.18426684]])"
"prompt_number": 125
"cell_type": "markdown",
"metadata": {},
"source": [
"Alternatively, we can perform regression without doing the matrix operations directly. SciPy has a function called `scipy.stats.linregress`, but it only handles simple (single-predictor) regression. For multiple regression, we can use the `LinearRegression` class from the `sklearn.linear_model` module. It took some googling to figure this out -- this is why I haven't moved away from R yet for data analysis. R is just much, much more mature."
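For the single-predictor case that `scipy.stats.linregress` covers, the least-squares line has a closed form: slope = cov(x, y) / var(x) and intercept = mean(y) - slope * mean(x). A numpy-only sketch on made-up data (seed and coefficients are hypothetical), cross-checked against a degree-1 `np.polyfit`:

```python
import numpy as np

# Hypothetical single-predictor data.
rng = np.random.default_rng(2)
x = rng.normal(size=50)
y = 0.5 + 3.0 * x + rng.normal(scale=0.2, size=50)

# Closed-form simple regression (same ddof in numerator and denominator)
slope = np.cov(x, y, ddof=1)[0, 1] / np.var(x, ddof=1)
intercept = y.mean() - slope * x.mean()

# Cross-check: polyfit returns coefficients highest degree first
check_slope, check_intercept = np.polyfit(x, y, 1)
print(np.allclose([slope, intercept], [check_slope, check_intercept]))  # True
```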
"cell_type": "code",
"collapsed": false,
"input": [
"from sklearn import linear_model\n",
"y = f[:,0]\n",
"x = f[:,1:3]\n",
"regr = linear_model.LinearRegression()\n",
"r = regr.fit(x, y)  # fit() returns the fitted estimator itself\n",
"print(\"Coefficients:\", r.coef_)\n",
"print(\"Intercept:\", r.intercept_)"
"language": "python",
"metadata": {},
"outputs": [
"output_type": "stream",
"stream": "stdout",
"text": [
"Coefficients: [ 0.161309 0.18426684]\n",
"Intercept: -0.0159785056214\n"
"prompt_number": 176
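Unlike the matrix versions, the `sklearn` cell passes `x` without a column of ones, because `LinearRegression` fits an intercept by default (`fit_intercept=True`). The numpy-only sketch below (made-up data, so it runs without scikit-learn) shows why that gives the same answer: fitting the slopes on mean-centred data and then setting intercept = mean(y) - mean(x) · slopes reproduces the coefficients from the explicit constant column. This illustrates the least-squares algebra; it is not a claim about how scikit-learn is implemented internally.

```python
import numpy as np

# Hypothetical data with two regressors.
rng = np.random.default_rng(3)
n = 80
x = rng.normal(size=(n, 2))
y = -0.2 + x @ np.array([0.15, 0.2]) + rng.normal(scale=0.5, size=n)

# Route 1: explicit constant column, as in the matrix versions
X1 = np.hstack((np.ones((n, 1)), x))
beta = np.linalg.solve(X1.T @ X1, X1.T @ y)

# Route 2: centre the data, fit the slopes only, then recover the intercept
xc = x - x.mean(axis=0)
yc = y - y.mean()
coef = np.linalg.solve(xc.T @ xc, xc.T @ yc)
intercept = y.mean() - x.mean(axis=0) @ coef

print(np.allclose(beta, np.r_[intercept, coef]))  # True
```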
"metadata": {}