@edwardeasling
Last active February 27, 2019 22:07
lesson2SGD_withmomentum
{
"cells": [
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"from fastai.basics import *"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this part of the lecture we explain Stochastic Gradient Descent (SGD), an **optimization** method commonly used to train neural networks. We illustrate the concepts with a concrete example."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Linear Regression problem"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The goal of linear regression is to fit a line to a set of points."
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"n=100"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 0.5494, 1.0000],\n",
" [-0.9676, 1.0000],\n",
" [-0.6262, 1.0000],\n",
" [-0.1999, 1.0000],\n",
" [ 0.7904, 1.0000]])"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = torch.ones(n,2) \n",
"x[:,0].uniform_(-1.,1)\n",
"x[:5]"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([3., 2.])"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a = tensor(3.,2); a"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
"y = x@a + torch.rand(n)"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.scatter(x[:,0], y);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We want to find **parameters** (weights) `a` that minimize the *error* between the points and the line `x@a`. Note that we treat `a` as unknown here, even though we generated the data with `a = (3, 2)`. For a regression problem the most common *error function*, or *loss function*, is the **mean squared error**."
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
"def mse(y_hat, y): return ((y_hat-y)**2).mean()"
]
},
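{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because the model is linear and the loss is MSE, the minimizing `a` can also be computed directly from the normal equations, $\\hat{a} = (X^\\top X)^{-1} X^\\top y$, which gives a reference point for what gradient descent should converge to. The cell below is a sketch added for illustration, not part of the original lesson: `make_data` and `fit_least_squares` are hypothetical helpers, and it builds its own synthetic data with the same recipe (from a private random generator) so that the notebook's `x`, `y` and `a` are left untouched. Because `torch.rand` draws the noise from $[0, 1)$, whose mean is 0.5, the fitted intercept should come out near 2.5 rather than 2, and the smallest achievable MSE is roughly the noise variance, $1/12 \\approx 0.083$."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def make_data(n, generator):\n",
"    # Same recipe as the lesson: column 0 uniform in [-1, 1), column 1 all ones (intercept)\n",
"    x = torch.ones(n, 2)\n",
"    x[:, 0].uniform_(-1., 1., generator=generator)\n",
"    y = x @ torch.tensor([3., 2.]) + torch.rand(n, generator=generator)\n",
"    return x, y\n",
"\n",
"def fit_least_squares(x, y):\n",
"    # Normal equations: a_hat = (X^T X)^-1 X^T y (fine for this tiny 2x2 system)\n",
"    xt = x.t()\n",
"    return torch.inverse(xt @ x) @ (xt @ y)\n",
"\n",
"# A private generator leaves the notebook's global random state untouched\n",
"gen_ls = torch.Generator().manual_seed(0)\n",
"x_ls, y_ls = make_data(1000, gen_ls)\n",
"a_star = fit_least_squares(x_ls, y_ls)\n",
"mse_star = ((x_ls @ a_star - y_ls) ** 2).mean()\n",
"print(a_star, mse_star)  # slope near 3, intercept near 2.5, MSE near 1/12"
]
},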
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Suppose we believe `a = (-1.0, 1.0)`. Then we can compute `y_hat`, our *prediction*, and from it our error."
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
"a = tensor(-1.,1)"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(7.5648)"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_hat = x@a\n",
"mse(y_hat, y)"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.scatter(x[:,0],y)\n",
"plt.scatter(x[:,0],y_hat);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"So far we have specified the *model* (linear regression) and the *evaluation criterion* (or *loss function*). Now we need to handle *optimization*: how do we find the values of `a` that give the best-fitting line?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Gradient Descent"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We would like to find the values of `a` that minimize `mse`.\n",
"\n",
"**Gradient descent** is an algorithm that minimizes functions. Given a function defined by a set of parameters, gradient descent starts with an initial set of parameter values and iteratively moves toward a set of parameter values that minimize the function. Each iteration takes a step in the direction of the negative gradient, the direction in which the function decreases fastest.\n",
"\n",
"Here is gradient descent with momentum implemented in [PyTorch](http://pytorch.org/)."
]
},
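{
"cell_type": "markdown",
"metadata": {},
"source": [
"For the MSE loss the gradient has a simple closed form, $\\nabla_a \\mathrm{MSE} = \\frac{2}{n} X^\\top (X a - y)$, and a plain gradient-descent step is $a \\leftarrow a - \\mathrm{lr} \\cdot \\nabla_a \\mathrm{MSE}$. The cell below is an illustrative sketch, not part of the original lesson (`mse_grad` and `plain_gd` are hypothetical helpers, and the data is its own synthetic copy): it checks the formula against PyTorch's autograd at the lesson's starting point `a = (-1, 1)` and then runs 50 plain gradient-descent steps without momentum."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def mse_grad(x, y, a):\n",
"    # Analytic gradient of mean((x @ a - y) ** 2) with respect to a\n",
"    n = x.shape[0]\n",
"    return (2.0 / n) * (x.t() @ (x @ a - y))\n",
"\n",
"def plain_gd(x, y, a0, lr=0.1, steps=50):\n",
"    # Plain full-batch gradient descent: each step moves straight down the current gradient\n",
"    a = a0.clone()\n",
"    for _ in range(steps):\n",
"        a = a - lr * mse_grad(x, y, a)\n",
"    return a\n",
"\n",
"gen_gd = torch.Generator().manual_seed(1)\n",
"x_gd = torch.ones(100, 2)\n",
"x_gd[:, 0].uniform_(-1., 1., generator=gen_gd)\n",
"y_gd = x_gd @ torch.tensor([3., 2.]) + torch.rand(100, generator=gen_gd)\n",
"\n",
"# Compare the formula with autograd at the lesson's starting point a = (-1, 1)\n",
"a_check = torch.tensor([-1., 1.], requires_grad=True)\n",
"((x_gd @ a_check - y_gd) ** 2).mean().backward()\n",
"print(torch.allclose(a_check.grad, mse_grad(x_gd, y_gd, a_check.detach()), atol=1e-5))\n",
"\n",
"a_gd = plain_gd(x_gd, y_gd, torch.tensor([-1., 1.]))\n",
"print(a_gd)"
]
},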
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Parameter containing:\n",
"tensor([-1., 1.], requires_grad=True)"
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a = nn.Parameter(a); a"
]
},
{
"cell_type": "code",
"execution_count": 135,
"metadata": {
"collapsed": true
},
"outputs": [
{
"data": {
"text/plain": [
"'def update(): \\n    y_hat = x@a\\n    loss = mse(y, y_hat)\\n    if t % 10 == 0: print(loss)\\n    loss.backward()\\n    with torch.no_grad():\\n        a.sub_(lr * a.grad)\\n        a.grad.zero_()'"
]
},
"execution_count": 135,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# The original update function (plain gradient descent, no momentum), kept for reference as a string\n",
"\"\"\"def update(): \n",
"    y_hat = x@a\n",
"    loss = mse(y, y_hat)\n",
"    if t % 10 == 0: print(loss)\n",
"    loss.backward()\n",
"    with torch.no_grad():\n",
"        a.sub_(lr * a.grad)\n",
"        a.grad.zero_()\"\"\" "
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"last_step = tensor(0.,0.)  # previous momentum step (zero before the first epoch)\n",
"set_moms = 0.8  # momentum coefficient: weight given to the previous step"
]
},
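{
"cell_type": "markdown",
"metadata": {},
"source": [
"Written as equations, the `update` function below keeps a running step $s_t$, where $g_t$ is the gradient at epoch $t$ and $\\beta$ is the momentum `set_moms` (0.8 here):\n",
"\n",
"$$s_0 = g_0, \\qquad s_t = \\beta\\, s_{t-1} + (1-\\beta)\\, g_t \\;\\; (t \\ge 1), \\qquad a \\leftarrow a - \\mathrm{lr}\\, s_t.$$\n",
"\n",
"Unrolling the recursion gives\n",
"\n",
"$$s_t = \\beta^{t} g_0 + (1-\\beta) \\sum_{k=1}^{t} \\beta^{t-k} g_k,$$\n",
"\n",
"and the weights add up to $\\beta^{t} + (1-\\beta)\\,\\frac{1-\\beta^{t}}{1-\\beta} = 1$. So each step is a weighted average of all gradients seen so far, with a gradient's weight shrinking by a factor of $\\beta$ every epoch. For example, at epoch 1 the first component of `last_step` is $0.8 \\times (-2.5998) + 0.2 \\times (-2.4394) \\approx -2.568$, matching the printed value."
]
},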
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
"def update():\n",
"    # uses the notebook globals x, y, a, set_moms, and t and lr from the training loop\n",
"    global last_step\n",
"    y_hat = x@a\n",
"    loss = mse(y_hat, y)\n",
"    loss.backward()\n",
"    with torch.no_grad():\n",
"        # at the first epoch there is no previous step, so the step is just the gradient\n",
"        moms = 0 if t == 0 else set_moms\n",
"        # momentum step: weighted average of the previous step and the current gradient\n",
"        step = moms*last_step + (1-moms)*a.grad\n",
"        a.sub_(lr * step)\n",
"        last_step = step\n",
"        # print loss, gradient and step to check they are calculated correctly\n",
"        print(f'Epoch {t}', loss, a.grad, last_step)\n",
"        a.grad.zero_()"
]
},
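{
"cell_type": "markdown",
"metadata": {},
"source": [
"PyTorch's built-in `torch.optim.SGD` implements the same rule: with `momentum=0.8` and `dampening=0.8` it keeps a buffer updated as $b \\leftarrow 0.8\\, b + (1 - 0.8)\\, g$ and starts the buffer at the first gradient, which matches the epoch-0 special case in `update`. The cell below is a sketch to check this, not part of the original lesson: `make_problem`, `manual_momentum` and `optim_momentum` are hypothetical helpers working on their own synthetic data, so the notebook's `a`, `last_step` and `t` are not touched. With PyTorch's default `dampening=0` the buffer becomes $b \\leftarrow \\beta\\, b + g$ instead; when the gradient stays roughly constant its steps are about $1/(1-\\beta) = 5$ times larger, so the learning rate would need to shrink accordingly."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def make_problem(seed):\n",
"    # Synthetic data with the lesson's recipe, drawn from a private generator\n",
"    gen = torch.Generator().manual_seed(seed)\n",
"    x = torch.ones(100, 2)\n",
"    x[:, 0].uniform_(-1., 1., generator=gen)\n",
"    y = x @ torch.tensor([3., 2.]) + torch.rand(100, generator=gen)\n",
"    return x, y\n",
"\n",
"def manual_momentum(x, y, lr=0.1, beta=0.8, steps=50):\n",
"    # Same rule as update(): step = gradient at t=0, then beta*last_step + (1-beta)*gradient\n",
"    a = torch.tensor([-1., 1.], requires_grad=True)\n",
"    last_step = torch.zeros(2)\n",
"    for t in range(steps):\n",
"        loss = ((x @ a - y) ** 2).mean()\n",
"        loss.backward()\n",
"        with torch.no_grad():\n",
"            m = 0. if t == 0 else beta\n",
"            last_step = m * last_step + (1 - m) * a.grad\n",
"            a.sub_(lr * last_step)\n",
"            a.grad.zero_()\n",
"    return a.detach()\n",
"\n",
"def optim_momentum(x, y, lr=0.1, beta=0.8, steps=50):\n",
"    # Built-in optimizer; dampening=beta scales the incoming gradient by (1 - beta)\n",
"    a = torch.tensor([-1., 1.], requires_grad=True)\n",
"    opt = torch.optim.SGD([a], lr=lr, momentum=beta, dampening=beta)\n",
"    for _ in range(steps):\n",
"        opt.zero_grad()\n",
"        loss = ((x @ a - y) ** 2).mean()\n",
"        loss.backward()\n",
"        opt.step()\n",
"    return a.detach()\n",
"\n",
"x_m, y_m = make_problem(2)\n",
"a_manual = manual_momentum(x_m, y_m)\n",
"a_optim = optim_momentum(x_m, y_m)\n",
"print(a_manual, a_optim, torch.allclose(a_manual, a_optim, atol=1e-5))"
]
},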
{
"cell_type": "code",
"execution_count": 47,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 0 tensor(7.5648, grad_fn=<MeanBackward1>) tensor([-2.5998, -2.8958]) tensor([-2.5998, -2.8958])\n",
"Epoch 1 tensor(6.1540, grad_fn=<MeanBackward1>) tensor([-2.4394, -2.3239]) tensor([-2.5678, -2.7814])\n",
"Epoch 2 tensor(4.9780, grad_fn=<MeanBackward1>) tensor([-2.2808, -1.7747]) tensor([-2.5104, -2.5801])\n",
"Epoch 3 tensor(4.0327, grad_fn=<MeanBackward1>) tensor([-2.1253, -1.2657]) tensor([-2.4333, -2.3172])\n",
"Epoch 4 tensor(3.2936, grad_fn=<MeanBackward1>) tensor([-1.9740, -0.8091]) tensor([-2.3415, -2.0156])\n",
"Epoch 5 tensor(2.7254, grad_fn=<MeanBackward1>) tensor([-1.8279, -0.4125]) tensor([-2.2388, -1.6950])\n",
"Epoch 6 tensor(2.2902, grad_fn=<MeanBackward1>) tensor([-1.6876, -0.0797]) tensor([-2.1285, -1.3719])\n",
"Epoch 7 tensor(1.9527, grad_fn=<MeanBackward1>) tensor([-1.5535, 0.1888]) tensor([-2.0135, -1.0598])\n",
"Epoch 8 tensor(1.6837, grad_fn=<MeanBackward1>) tensor([-1.4259, 0.3951]) tensor([-1.8960, -0.7688])\n",
"Epoch 9 tensor(1.4609, grad_fn=<MeanBackward1>) tensor([-1.3052, 0.5436]) tensor([-1.7778, -0.5063])\n",
"Epoch 10 tensor(1.2689, grad_fn=<MeanBackward1>) tensor([-1.1914, 0.6399]) tensor([-1.6605, -0.2771])\n",
"Epoch 11 tensor(1.0984, grad_fn=<MeanBackward1>) tensor([-1.0846, 0.6907]) tensor([-1.5454, -0.0835])\n",
"Epoch 12 tensor(0.9443, grad_fn=<MeanBackward1>) tensor([-0.9847, 0.7031]) tensor([-1.4332, 0.0738])\n",
"Epoch 13 tensor(0.8047, grad_fn=<MeanBackward1>) tensor([-0.8916, 0.6843]) tensor([-1.3249, 0.1959])\n",
"Epoch 14 tensor(0.6794, grad_fn=<MeanBackward1>) tensor([-0.8052, 0.6415]) tensor([-1.2209, 0.2850])\n",
"Epoch 15 tensor(0.5685, grad_fn=<MeanBackward1>) tensor([-0.7253, 0.5811]) tensor([-1.1218, 0.3442])\n",
"Epoch 16 tensor(0.4725, grad_fn=<MeanBackward1>) tensor([-0.6516, 0.5091]) tensor([-1.0278, 0.3772])\n",
"Epoch 17 tensor(0.3913, grad_fn=<MeanBackward1>) tensor([-0.5840, 0.4308]) tensor([-0.9390, 0.3879])\n",
"Epoch 18 tensor(0.3242, grad_fn=<MeanBackward1>) tensor([-0.5220, 0.3506]) tensor([-0.8556, 0.3805])\n",
"Epoch 19 tensor(0.2701, grad_fn=<MeanBackward1>) tensor([-0.4655, 0.2721]) tensor([-0.7776, 0.3588])\n",
"Epoch 20 tensor(0.2275, grad_fn=<MeanBackward1>) tensor([-0.4141, 0.1982]) tensor([-0.7049, 0.3267])\n",
"Epoch 21 tensor(0.1946, grad_fn=<MeanBackward1>) tensor([-0.3675, 0.1309]) tensor([-0.6374, 0.2875])\n",
"Epoch 22 tensor(0.1696, grad_fn=<MeanBackward1>) tensor([-0.3254, 0.0716]) tensor([-0.5750, 0.2443])\n",
"Epoch 23 tensor(0.1508, grad_fn=<MeanBackward1>) tensor([-0.2875, 0.0212]) tensor([-0.5175, 0.1997])\n",
"Epoch 24 tensor(0.1368, grad_fn=<MeanBackward1>) tensor([-0.2534, -0.0202]) tensor([-0.4647, 0.1557])\n",
"Epoch 25 tensor(0.1263, grad_fn=<MeanBackward1>) tensor([-0.2229, -0.0527]) tensor([-0.4163, 0.1140])\n",
"Epoch 26 tensor(0.1183, grad_fn=<MeanBackward1>) tensor([-0.1956, -0.0766]) tensor([-0.3722, 0.0759])\n",
"Epoch 27 tensor(0.1121, grad_fn=<MeanBackward1>) tensor([-0.1712, -0.0928]) tensor([-0.3320, 0.0422])\n",
"Epoch 28 tensor(0.1072, grad_fn=<MeanBackward1>) tensor([-0.1496, -0.1022]) tensor([-0.2955, 0.0133])\n",
"Epoch 29 tensor(0.1032, grad_fn=<MeanBackward1>) tensor([-0.1304, -0.1057]) tensor([-0.2625, -0.0105])\n",
"Epoch 30 tensor(0.0999, grad_fn=<MeanBackward1>) tensor([-0.1134, -0.1043]) tensor([-0.2327, -0.0293])\n",
"Epoch 31 tensor(0.0972, grad_fn=<MeanBackward1>) tensor([-0.0984, -0.0991]) tensor([-0.2058, -0.0432])\n",
"Epoch 32 tensor(0.0949, grad_fn=<MeanBackward1>) tensor([-0.0852, -0.0910]) tensor([-0.1817, -0.0528])\n",
"Epoch 33 tensor(0.0930, grad_fn=<MeanBackward1>) tensor([-0.0736, -0.0810]) tensor([-0.1601, -0.0584])\n",
"Epoch 34 tensor(0.0914, grad_fn=<MeanBackward1>) tensor([-0.0634, -0.0697]) tensor([-0.1407, -0.0607])\n",
"Epoch 35 tensor(0.0902, grad_fn=<MeanBackward1>) tensor([-0.0544, -0.0580]) tensor([-0.1235, -0.0602])\n",
"Epoch 36 tensor(0.0893, grad_fn=<MeanBackward1>) tensor([-0.0466, -0.0463]) tensor([-0.1081, -0.0574])\n",
"Epoch 37 tensor(0.0886, grad_fn=<MeanBackward1>) tensor([-0.0397, -0.0351]) tensor([-0.0944, -0.0529])\n",
"Epoch 38 tensor(0.0881, grad_fn=<MeanBackward1>) tensor([-0.0338, -0.0248]) tensor([-0.0823, -0.0473])\n",
"Epoch 39 tensor(0.0877, grad_fn=<MeanBackward1>) tensor([-0.0286, -0.0156]) tensor([-0.0716, -0.0410])\n",
"Epoch 40 tensor(0.0875, grad_fn=<MeanBackward1>) tensor([-0.0240, -0.0076]) tensor([-0.0621, -0.0343])\n",
"Epoch 41 tensor(0.0873, grad_fn=<MeanBackward1>) tensor([-0.0201, -0.0009]) tensor([-0.0537, -0.0276])\n",
"Epoch 42 tensor(0.0872, grad_fn=<MeanBackward1>) tensor([-0.0167, 0.0045]) tensor([-0.0463, -0.0212])\n",
"Epoch 43 tensor(0.0872, grad_fn=<MeanBackward1>) tensor([-0.0138, 0.0086]) tensor([-0.0398, -0.0152])\n",
"Epoch 44 tensor(0.0871, grad_fn=<MeanBackward1>) tensor([-0.0112, 0.0115]) tensor([-0.0341, -0.0099])\n",
"Epoch 45 tensor(0.0871, grad_fn=<MeanBackward1>) tensor([-0.0091, 0.0134]) tensor([-0.0291, -0.0052])\n",
"Epoch 46 tensor(0.0871, grad_fn=<MeanBackward1>) tensor([-0.0072, 0.0144]) tensor([-0.0247, -0.0013])\n",
"Epoch 47 tensor(0.0871, grad_fn=<MeanBackward1>) tensor([-0.0056, 0.0146]) tensor([-0.0209, 0.0019])\n",
"Epoch 48 tensor(0.0871, grad_fn=<MeanBackward1>) tensor([-0.0042, 0.0141]) tensor([-0.0176, 0.0043])\n",
"Epoch 49 tensor(0.0871, grad_fn=<MeanBackward1>) tensor([-0.0031, 0.0132]) tensor([-0.0147, 0.0061])\n"
]
}
],
"source": [
"lr = 1e-1\n",
"for t in range(50): update()"
]
},
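{
"cell_type": "markdown",
"metadata": {},
"source": [
"The loop above computes the loss on all 100 points at every epoch, so strictly speaking it is full-batch gradient descent with momentum. *Stochastic* gradient descent instead estimates the gradient from a random mini-batch at each step. The cell below is an illustrative sketch, not part of the original lesson: `minibatch_sgd_momentum`, the batch size of 10, the 20 epochs and the seeds are assumptions chosen for the example. It applies the same momentum rule as `update` to mini-batches drawn with `torch.randperm`, on its own synthetic data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def minibatch_sgd_momentum(x, y, lr=0.1, beta=0.8, epochs=20, bs=10, seed=0):\n",
"    # Same momentum rule as update(), but every step uses a random mini-batch of bs points\n",
"    gen = torch.Generator().manual_seed(seed)\n",
"    a = torch.tensor([-1., 1.], requires_grad=True)\n",
"    last_step = torch.zeros(2)\n",
"    first = True\n",
"    for _ in range(epochs):\n",
"        perm = torch.randperm(x.shape[0], generator=gen)  # new shuffle each epoch\n",
"        for i in range(0, x.shape[0], bs):\n",
"            idx = perm[i:i + bs]\n",
"            loss = ((x[idx] @ a - y[idx]) ** 2).mean()\n",
"            loss.backward()\n",
"            with torch.no_grad():\n",
"                m = 0. if first else beta  # no previous step on the very first update\n",
"                last_step = m * last_step + (1 - m) * a.grad\n",
"                a.sub_(lr * last_step)\n",
"                a.grad.zero_()\n",
"            first = False\n",
"    return a.detach()\n",
"\n",
"gen_data = torch.Generator().manual_seed(3)\n",
"x_sgd = torch.ones(100, 2)\n",
"x_sgd[:, 0].uniform_(-1., 1., generator=gen_data)\n",
"y_sgd = x_sgd @ torch.tensor([3., 2.]) + torch.rand(100, generator=gen_data)\n",
"\n",
"a_sgd = minibatch_sgd_momentum(x_sgd, y_sgd)\n",
"print(a_sgd, ((x_sgd @ a_sgd - y_sgd) ** 2).mean())"
]
},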
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.scatter(x[:,0],y)\n",
"plt.scatter(x[:,0],x@a);"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 1
}