@gngdb
Last active April 27, 2020 04:03
Least Squares in PyTorch
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Adapting [this excellent blog post](http://drsfenner.org/blog/2015/12/three-paths-to-least-squares-linear-regression/) to show three ways to solve a linear least squares problem in PyTorch: via the normal equations (Cholesky), via QR and via SVD."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"def create_data(n=5000, p=100):\n",
"    '''\n",
"    n is the number of cases/observations/examples\n",
"    p is the number of features/attributes/variables\n",
"    '''\n",
"    X = torch.rand(n,p)*10.\n",
"    coeffs = (torch.rand(p)*10.).view(p,1)\n",
"    def f(X): return X.mm(coeffs)\n",
"\n",
"    noise = torch.randn(n,1)\n",
"    Y = f(X) + noise\n",
"    Y = Y.view(n,1)\n",
"\n",
"    return X,Y"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"X,Y = create_data()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Via Cholesky\n",
"\n",
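"Unfortunately, there is no `torch.solve`, but there is `torch.gesv`, which [solves systems of linear equations](https://pytorch.org/docs/stable/torch.html?highlight=least%20squares#torch.gesv). Strictly speaking, `gesv` factorizes the matrix with LU rather than Cholesky, but because `X^T X` is symmetric positive definite here, both give the same solution to the normal equations `X^T X beta = X^T Y`. Compared with NumPy, the arguments are reversed, with the right-hand side first:\n",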
"\n",
"```\n",
"torch.gesv(B,A)\n",
"AX = B\n",
"```\n",
"\n",
"```\n",
"numpy.linalg.solve(a,b)\n",
"ax = b\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
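"# Form the normal equations X^T X beta = X^T Y\n",
"XtX, XtY = X.permute(1,0).mm(X), X.permute(1,0).mm(Y)\n",
"# gesv returns (solution, LU factors); only the solution is needed\n",
"betas_cholesky, _ = torch.gesv(XtY, XtX)"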
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([100, 1])"
]
},
"execution_count": 51,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"betas_cholesky.size()"
]
},
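{
"cell_type": "markdown",
"metadata": {},
"source": [
"On current PyTorch releases `torch.gesv` no longer exists; its replacement is `torch.linalg.solve`, which takes the matrix first and the right-hand side second, like `numpy.linalg.solve`. Since this section is named after Cholesky, the next cell also sketches an explicit Cholesky solve of the normal equations with `torch.linalg.cholesky` and `torch.cholesky_solve`. It is a minimal sketch assuming PyTorch 1.8 or later; the float64 toy problem and all `_demo` variable names are hypothetical, chosen so the cell does not overwrite the notebook's `X`, `Y` or betas."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Hypothetical float64 toy problem; a private generator leaves the global RNG untouched\n",
"g = torch.Generator().manual_seed(0)\n",
"X_demo = torch.rand(500, 10, dtype=torch.float64, generator=g) * 10.\n",
"coeffs_demo = torch.rand(10, 1, dtype=torch.float64, generator=g) * 10.\n",
"Y_demo = X_demo @ coeffs_demo + torch.randn(500, 1, dtype=torch.float64, generator=g)\n",
"\n",
"# Normal equations: (X^T X) beta = X^T Y\n",
"XtX_demo, XtY_demo = X_demo.T @ X_demo, X_demo.T @ Y_demo\n",
"\n",
"# Current replacement for torch.gesv(XtY, XtX): the matrix now comes first, as in NumPy\n",
"betas_solve_demo = torch.linalg.solve(XtX_demo, XtY_demo)\n",
"\n",
"# Explicit Cholesky route: XtX = L L^T with L lower triangular\n",
"L_demo = torch.linalg.cholesky(XtX_demo)\n",
"# cholesky_solve takes the right-hand side first and the Cholesky factor second\n",
"betas_chol_demo = torch.cholesky_solve(XtY_demo, L_demo)\n",
"\n",
"print(torch.allclose(betas_solve_demo, betas_chol_demo))"
]
},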
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Via QR\n",
"\n",
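"PyTorch does have [`torch.qr`](https://pytorch.org/docs/stable/torch.html?highlight=qr#torch.qr), but the blog post uses `gels`, LAPACK's least squares driver, which solves overdetermined systems like this one with a QR factorization. Luckily, PyTorch exposes it as `torch.gels`, again taking the right-hand side first:"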
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"betas_qr,_ = torch.gels(Y,X)"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
"# gels returns max(n, p) rows; only the first p rows hold the solution\n",
"betas_qr = betas_qr[:X.size(1)]"
]
},
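{
"cell_type": "markdown",
"metadata": {},
"source": [
"`torch.gels` has likewise been removed from current PyTorch. The next cell is a minimal sketch, assuming PyTorch 1.11 or later (for `torch.linalg.solve_triangular`), of two replacements: an explicit QR solve, where the reduced factorization `X = QR` turns least squares into the triangular system `R beta = Q^T Y`, and the library routine `torch.linalg.lstsq`, whose solution already has exactly `p` rows, so no slicing is needed. The float64 toy problem and the `_demo` names are hypothetical."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Hypothetical float64 toy problem; _demo names avoid overwriting the notebook's variables\n",
"g = torch.Generator().manual_seed(0)\n",
"X_demo = torch.rand(500, 10, dtype=torch.float64, generator=g) * 10.\n",
"coeffs_demo = torch.rand(10, 1, dtype=torch.float64, generator=g) * 10.\n",
"Y_demo = X_demo @ coeffs_demo + torch.randn(500, 1, dtype=torch.float64, generator=g)\n",
"\n",
"# Reduced QR (the default mode): Q is 500 x 10 with orthonormal columns, R is 10 x 10 upper triangular\n",
"Q_demo, R_demo = torch.linalg.qr(X_demo)\n",
"# Least squares then reduces to the triangular system R beta = Q^T Y\n",
"betas_qr_demo = torch.linalg.solve_triangular(R_demo, Q_demo.T @ Y_demo, upper=True)\n",
"\n",
"# Library routine: its solution has exactly p rows, unlike the output of gels\n",
"betas_lstsq_demo = torch.linalg.lstsq(X_demo, Y_demo).solution\n",
"\n",
"print(betas_qr_demo.shape, torch.allclose(betas_qr_demo, betas_lstsq_demo))"
]
},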
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Via SVD\n",
"\n",
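"PyTorch has no equivalent of the `lstsq` function used in this section of the blog post. It does have `torch.svd`, though, so we can assemble the least squares solution by hand: for the thin SVD `X = U diag(S) V^T`, the solution is `beta = V diag(1/S) U^T Y`. Note that `torch.svd` returns `V` itself rather than `V^T`."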
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [],
"source": [
"U, S, V = torch.svd(X)"
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {},
"outputs": [],
"source": [
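"# beta = V diag(1/S) U^T Y for the thin SVD X = U diag(S) V^T\n",
"S_inv = (1./S).view(1,S.size(0))\n",
"# inverse of a diagonal matrix is the reciprocal of its diagonal; broadcasting scales column j of V by 1/S[j]\n",
"VS = V*S_inv\n",
"UtY = torch.mm(U.permute(1,0), Y)\n",
"betas_svd = torch.mm(VS, UtY)"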
]
},
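{
"cell_type": "markdown",
"metadata": {},
"source": [
"In current PyTorch, `torch.linalg.svd` is the preferred SVD. One difference to watch: it returns `Vh` (that is, `V^T`), whereas `torch.svd` returns `V`. Current releases also provide `torch.linalg.lstsq`, whose `gelsd` driver is SVD-based. The next cell is a minimal sketch, assuming PyTorch 1.9 or later, of the same `beta = V diag(1/S) U^T Y` formula, cross-checked against `torch.linalg.pinv` and `torch.linalg.lstsq`; the float64 toy problem and the `_demo` names are hypothetical."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Hypothetical float64 toy problem; _demo names avoid overwriting the notebook's variables\n",
"g = torch.Generator().manual_seed(0)\n",
"X_demo = torch.rand(500, 10, dtype=torch.float64, generator=g) * 10.\n",
"coeffs_demo = torch.rand(10, 1, dtype=torch.float64, generator=g) * 10.\n",
"Y_demo = X_demo @ coeffs_demo + torch.randn(500, 1, dtype=torch.float64, generator=g)\n",
"\n",
"# Thin SVD X = U diag(S) Vh; torch.linalg.svd returns Vh = V^T, whereas torch.svd returns V\n",
"U_demo, S_demo, Vh_demo = torch.linalg.svd(X_demo, full_matrices=False)\n",
"# beta = V diag(1/S) U^T Y; dividing by S as a column vector scales row j of U^T Y by 1/S[j]\n",
"betas_svd_demo = Vh_demo.T @ ((U_demo.T @ Y_demo) / S_demo.unsqueeze(1))\n",
"\n",
"# Cross-checks: the pseudoinverse, and lstsq with the SVD-based gelsd driver (CPU tensors only)\n",
"betas_pinv_demo = torch.linalg.pinv(X_demo) @ Y_demo\n",
"betas_gelsd_demo = torch.linalg.lstsq(X_demo, Y_demo, driver='gelsd').solution\n",
"\n",
"print(torch.allclose(betas_svd_demo, betas_pinv_demo), torch.allclose(betas_svd_demo, betas_gelsd_demo))"
]
},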
{
"cell_type": "markdown",
"metadata": {},
"source": [
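"# Comparing Betas\n",
"\n",
"Each row below lists one coefficient as estimated via Cholesky, QR and SVD, in that order. All three agree to within about `1e-3`, which is reasonable for float32 arithmetic; the QR and SVD columns usually agree more closely with each other than with the Cholesky (normal equations) column."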
]
},
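{
"cell_type": "markdown",
"metadata": {},
"source": [
"Rather than reading the 100 rows printed by the loop below by eye, the agreement can be summarized with a single number. The next cell is a minimal sketch, assuming PyTorch 1.9 or later, on a hypothetical float64 toy problem of the same size as `create_data`'s defaults (all variable names are illustrative). It reports the largest absolute gap between the normal-equations solution and `torch.linalg.lstsq`, and checks the standard fact that forming `X^T X` squares the 2-norm condition number, which is why the normal-equations route tends to lose the most precision."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Hypothetical float64 toy problem with create_data's default sizes (n=5000, p=100)\n",
"g = torch.Generator().manual_seed(0)\n",
"X_demo = torch.rand(5000, 100, dtype=torch.float64, generator=g) * 10.\n",
"coeffs_demo = torch.rand(100, 1, dtype=torch.float64, generator=g) * 10.\n",
"Y_demo = X_demo @ coeffs_demo + torch.randn(5000, 1, dtype=torch.float64, generator=g)\n",
"\n",
"# One-number summary: largest absolute gap between normal-equations and lstsq estimates\n",
"betas_normal_demo = torch.linalg.solve(X_demo.T @ X_demo, X_demo.T @ Y_demo)\n",
"betas_lstsq_demo = torch.linalg.lstsq(X_demo, Y_demo).solution\n",
"max_gap = (betas_normal_demo - betas_lstsq_demo).abs().max().item()\n",
"\n",
"# Forming X^T X squares the 2-norm condition number\n",
"cond_X = torch.linalg.cond(X_demo).item()\n",
"cond_XtX = torch.linalg.cond(X_demo.T @ X_demo).item()\n",
"\n",
"print(max_gap, cond_X ** 2, cond_XtX)"
]
},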
{
"cell_type": "code",
"execution_count": 75,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"4.789241313934326 4.789309501647949 4.789262294769287\n",
"9.147521018981934 9.14731502532959 9.147207260131836\n",
"5.609707355499268 5.60968017578125 5.609684467315674\n",
"7.863295078277588 7.86319637298584 7.863215446472168\n",
"7.966612815856934 7.966621398925781 7.966643333435059\n",
"0.9358347058296204 0.9359120726585388 0.935868501663208\n",
"6.460849285125732 6.460752487182617 6.46080207824707\n",
"2.361510753631592 2.3616764545440674 2.3616724014282227\n",
"3.4893341064453125 3.4897570610046387 3.489757776260376\n",
"0.40589192509651184 0.4058658182621002 0.4058196544647217\n",
"6.3950676918029785 6.395297527313232 6.395309925079346\n",
"4.863685131072998 4.863367557525635 4.863370418548584\n",
"9.596863746643066 9.596785545349121 9.596807479858398\n",
"3.2956292629241943 3.2953946590423584 3.295461654663086\n",
"0.7079331874847412 0.7079015374183655 0.7079324722290039\n",
"7.927335739135742 7.927225589752197 7.92726993560791\n",
"1.7596901655197144 1.7599283456802368 1.759899616241455\n",
"3.1518752574920654 3.151834487915039 3.1518335342407227\n",
"6.198609352111816 6.198548793792725 6.198549270629883\n",
"2.877772092819214 2.877809762954712 2.877819776535034\n",
"0.1144978478550911 0.11482840776443481 0.11482787132263184\n",
"1.659272313117981 1.6594208478927612 1.6593914031982422\n",
"8.555513381958008 8.555625915527344 8.555619239807129\n",
"6.668213844299316 6.668307304382324 6.668299198150635\n",
"4.201225280761719 4.200991153717041 4.201006889343262\n",
"8.882617950439453 8.882568359375 8.882552146911621\n",
"1.1241998672485352 1.1241655349731445 1.1241430044174194\n",
"9.722527503967285 9.722660064697266 9.722660064697266\n",
"7.96357536315918 7.96376895904541 7.963817119598389\n",
"4.061668395996094 4.061661243438721 4.061673164367676\n",
"0.4477618336677551 0.4477052688598633 0.44772469997406006\n",
"8.820137977600098 8.819872856140137 8.819879531860352\n",
"2.5416290760040283 2.5413544178009033 2.5413389205932617\n",
"4.627619743347168 4.627681732177734 4.627681255340576\n",
"0.7478161454200745 0.7477755546569824 0.7478164434432983\n",
"7.434598922729492 7.434365749359131 7.434381008148193\n",
"5.782316207885742 5.782223701477051 5.7822265625\n",
"7.303110122680664 7.30283260345459 7.302789688110352\n",
"9.179882049560547 9.180000305175781 9.180000305175781\n",
"3.4134645462036133 3.413600206375122 3.41367506980896\n",
"8.272011756896973 8.272011756896973 8.272056579589844\n",
"8.254677772521973 8.254998207092285 8.254950523376465\n",
"4.843873500823975 4.843794822692871 4.8437299728393555\n",
"8.814054489135742 8.814101219177246 8.814138412475586\n",
"3.537532329559326 3.537382125854492 3.537411689758301\n",
"3.009495735168457 3.009181022644043 3.0092389583587646\n",
"2.756869077682495 2.7569947242736816 2.7569937705993652\n",
"1.9333038330078125 1.93343985080719 1.9334814548492432\n",
"9.179861068725586 9.179898262023926 9.179922103881836\n",
"7.7894978523254395 7.789270401000977 7.789253234863281\n",
"4.843989372253418 4.843654155731201 4.843679904937744\n",
"6.666652202606201 6.666760444641113 6.666749477386475\n",
"6.028740882873535 6.028807640075684 6.0288286209106445\n",
"9.00769329071045 9.007808685302734 9.007761001586914\n",
"0.13813860714435577 0.13820038735866547 0.13817358016967773\n",
"9.304071426391602 9.304142951965332 9.304139137268066\n",
"4.384835720062256 4.384726047515869 4.384748458862305\n",
"4.6577630043029785 4.657683849334717 4.657710552215576\n",
"9.897286415100098 9.897216796875 9.897197723388672\n",
"1.960221290588379 1.960351586341858 1.960319995880127\n",
"9.509464263916016 9.509519577026367 9.509477615356445\n",
"1.4898333549499512 1.4899165630340576 1.4899344444274902\n",
"8.848724365234375 8.848979949951172 8.848989486694336\n",
"6.172708511352539 6.172853469848633 6.172830581665039\n",
"7.406793594360352 7.406737804412842 7.406713008880615\n",
"1.8598474264144897 1.8596168756484985 1.8596140146255493\n",
"7.90094518661499 7.901154041290283 7.901152610778809\n",
"1.0882967710494995 1.0882912874221802 1.0882902145385742\n",
"7.895720958709717 7.895816326141357 7.895867347717285\n",
"5.307201385498047 5.307127475738525 5.307128429412842\n",
"1.5453896522521973 1.545369029045105 1.5453554391860962\n",
"7.144125461578369 7.1442389488220215 7.144256591796875\n",
"8.263419151306152 8.26330280303955 8.263258934020996\n",
"4.024933815002441 4.025058269500732 4.025018215179443\n",
"8.888937950134277 8.889105796813965 8.889098167419434\n",
"4.1783928871154785 4.178523063659668 4.178526401519775\n",
"9.600228309631348 9.600030899047852 9.600034713745117\n",
"5.944829940795898 5.944741725921631 5.944738864898682\n",
"5.6164655685424805 5.616551876068115 5.616508483886719\n",
"4.597523212432861 4.597728729248047 4.597701549530029\n",
"7.265918254852295 7.265911102294922 7.265933036804199\n",
"4.980377197265625 4.980417728424072 4.98041296005249\n",
"8.399260520935059 8.399263381958008 8.399232864379883\n",
"2.6021928787231445 2.602515459060669 2.6025009155273438\n",
"6.948577404022217 6.9487457275390625 6.94873046875\n",
"0.7757428884506226 0.7756509780883789 0.7756406664848328\n",
"1.3978784084320068 1.3978636264801025 1.3978710174560547\n",
"8.347896575927734 8.34741497039795 8.347456932067871\n",
"3.9364943504333496 3.93650484085083 3.936490297317505\n",
"5.9613776206970215 5.96109676361084 5.961121082305908\n",
"3.430814743041992 3.430952310562134 3.4309439659118652\n",
"9.821599006652832 9.821892738342285 9.821885108947754\n",
"8.789693832397461 8.789381980895996 8.78941822052002\n",
"3.9870569705963135 3.9870316982269287 3.9870190620422363\n",
"5.418898582458496 5.419239044189453 5.41928768157959\n",
"6.473894119262695 6.474045753479004 6.474076747894287\n",
"3.898988962173462 3.898843288421631 3.8988606929779053\n",
"0.7752768397331238 0.7750235199928284 0.7750210762023926\n",
"9.980328559875488 9.980111122131348 9.98007583618164\n",
"0.684233546257019 0.6842225193977356 0.6841509342193604\n"
]
}
],
"source": [
"for i in range(betas_cholesky.size(0)):\n",
"    print(betas_cholesky[i,0].item(),\n",
"          betas_qr[i,0].item(),\n",
"          betas_svd[i,0].item())"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}