
@bigsnarfdude
Last active December 17, 2018 18:54
From Appendix B in the paper Neural ODE Solver Implementation using autograd https://arxiv.org/abs/1806.07366
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"Autograd: Automatic Differentiation\n",
"===================================\n",
"\n",
"Central to all neural networks in PyTorch is the ``autograd`` package.\n",
"Let’s first briefly visit this, and we will then go on to train our\n",
"first neural network.\n",
"\n",
"\n",
"The ``autograd`` package provides automatic differentiation for all operations\n",
"on Tensors. It is a define-by-run framework, which means that your backprop is\n",
"defined by how your code is run, and that every single iteration can be\n",
"different.\n",
"\n",
"Let us see this in simpler terms with some examples.\n",
"\n",
"Tensor\n",
"--------\n",
"\n",
"``torch.Tensor`` is the central class of the package. If you set its attribute\n",
"``.requires_grad`` to ``True``, it starts to track all operations on it. When\n",
"you finish your computation you can call ``.backward()`` and have all the\n",
"gradients computed automatically. The gradient for this tensor will be\n",
"accumulated into the ``.grad`` attribute.\n",
"\n",
"To stop a tensor from tracking history, you can call ``.detach()``, which returns\n",
"a new tensor detached from the computation history, so that future computation\n",
"on it is not tracked.\n",
"\n",
"To prevent tracking history (and using memory), you can also wrap the code block\n",
"in ``with torch.no_grad():``. This can be particularly helpful when evaluating a\n",
"model because the model may have trainable parameters with `requires_grad=True`,\n",
"but for which we don't need the gradients.\n",
"\n",
"There’s one more class which is very important for autograd\n",
"implementation: a ``Function``.\n",
"\n",
"``Tensor`` and ``Function`` are interconnected and build up an acyclic\n",
"graph that encodes a complete history of computation. Each tensor has\n",
"a ``.grad_fn`` attribute that references the ``Function`` that created\n",
"the ``Tensor`` (except for Tensors created by the user, whose\n",
"``grad_fn`` is ``None``).\n",
"\n",
"If you want to compute the derivatives, you can call ``.backward()`` on\n",
"a ``Tensor``. If the ``Tensor`` is a scalar (i.e. it holds a single element),\n",
"you don’t need to specify any arguments to ``backward()``;\n",
"however, if it has more elements, you need to specify a ``gradient``\n",
"argument that is a tensor of matching shape.\n",
"\n"
]
},
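{
"cell_type": "markdown",
"metadata": {},
"source": [
"A short sketch to make the points above concrete (it is not part of the original tutorial, and the names ``s`` and ``d`` are illustrative): a tensor created by the user has ``grad_fn`` equal to ``None``, the result of an operation records a ``grad_fn``, ``.detach()`` returns a tensor that no longer tracks history, and a scalar result can call ``.backward()`` with no arguments.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"x = torch.ones(2, 2, requires_grad=True)\n",
"print(x.grad_fn)                   # None: x was created by the user\n",
"\n",
"s = (x * 3).sum()\n",
"print(s.grad_fn)                   # a SumBackward0 node recorded by autograd\n",
"\n",
"d = (x * 3).detach()               # same values, but cut off from the graph\n",
"print(d.requires_grad, d.grad_fn)  # False None\n",
"\n",
"s.backward()                       # s is a scalar, so no gradient argument is needed\n",
"print(x.grad)                      # d(s)/dx is 3 in every entry"
]
},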
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'1.0.0'"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import torch\n",
"torch.__version__"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Create a tensor and set requires_grad=True to track computation with it\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[1., 1.],\n",
" [1., 1.]], requires_grad=True)\n"
]
}
],
"source": [
"x = torch.ones(2, 2, requires_grad=True)\n",
"print(x)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Do an operation on the tensor:\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[3., 3.],\n",
" [3., 3.]], grad_fn=<AddBackward0>)\n"
]
}
],
"source": [
"y = x + 2\n",
"print(y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"``y`` was created as a result of an operation, so it has a ``grad_fn``.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<AddBackward0 object at 0x000001E23DD5C438>\n"
]
}
],
"source": [
"print(y.grad_fn)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Do more operations on ``y``:\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[27., 27.],\n",
" [27., 27.]], grad_fn=<MulBackward0>) tensor(27., grad_fn=<MeanBackward1>)\n"
]
}
],
"source": [
"z = y * y * 3\n",
"out = z.mean()\n",
"\n",
"print(z, out)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
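"``.requires_grad_( ... )`` changes an existing Tensor's ``requires_grad``\n",
"flag in-place. Tensors are created with ``requires_grad=False`` unless told otherwise,\n",
"while ``.requires_grad_()`` called without an argument sets the flag to ``True``.\n",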
"\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"False\n",
"True\n",
"<SumBackward0 object at 0x000001E23FD2C438>\n"
]
}
],
"source": [
"a = torch.randn(2, 2)\n",
"a = ((a * 3) / (a - 1))\n",
"print(a.requires_grad)\n",
"a.requires_grad_(True)\n",
"print(a.requires_grad)\n",
"b = (a * a).sum()\n",
"print(b.grad_fn)"
]
},
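{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal check of these defaults (a sketch, not from the original tutorial): new tensors start with ``requires_grad=False``, and ``.requires_grad_()`` called without an argument switches tracking on.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"a = torch.randn(2, 2)\n",
"print(a.requires_grad)   # False: new tensors do not track history by default\n",
"a.requires_grad_()       # with no argument, the flag is set to True\n",
"print(a.requires_grad)   # True\n",
"a.requires_grad_(False)  # pass False explicitly to switch tracking off again\n",
"print(a.requires_grad)   # False"
]
},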
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Gradients\n",
"---------\n",
"Let's backprop now.\n",
"Because ``out`` contains a single scalar, ``out.backward()`` is\n",
"equivalent to ``out.backward(torch.tensor(1.))``.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"out.backward()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Print the gradients d(out)/dx:\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[4.5000, 4.5000],\n",
" [4.5000, 4.5000]])\n"
]
}
],
"source": [
"print(x.grad)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You should have got a matrix of ``4.5``. Let’s call the ``out``\n",
"*Tensor* “$o$”.\n",
"We have that $o = \\frac{1}{4}\\sum_i z_i$,\n",
"$z_i = 3(x_i+2)^2$ and $z_i\\bigr\\rvert_{x_i=1} = 27$.\n",
"Therefore,\n",
"$\\frac{\\partial o}{\\partial x_i} = \\frac{1}{4}\\cdot 6(x_i+2) = \\frac{3}{2}(x_i+2)$, hence\n",
"$\\frac{\\partial o}{\\partial x_i}\\bigr\\rvert_{x_i=1} = \\frac{9}{2} = 4.5$.\n",
"\n"
]
},
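{
"cell_type": "markdown",
"metadata": {},
"source": [
"An independent check of the hand derivation (a sketch, not from the original tutorial; the helper ``o_of`` and the step ``eps`` are illustrative choices): recompute ``out`` from scratch, let autograd differentiate it, and compare one entry with a central finite difference. Double precision is used only to keep the finite-difference rounding error small.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def o_of(x):\n",
"    # Same computation as above: z = 3 * (x + 2)^2, then the mean\n",
"    return (3 * (x + 2) ** 2).mean()\n",
"\n",
"x = torch.ones(2, 2, dtype=torch.float64, requires_grad=True)\n",
"o_of(x).backward()\n",
"print(x.grad)                       # 4.5 in every entry, i.e. 3/2 * (1 + 2)\n",
"\n",
"# Central finite difference with respect to the first entry only\n",
"eps = 1e-4\n",
"e = torch.zeros(2, 2, dtype=torch.float64)\n",
"e[0, 0] = eps\n",
"with torch.no_grad():\n",
"    approx = (o_of(x + e) - o_of(x - e)) / (2 * eps)\n",
"print(approx.item())                # approximately 4.5"
]
},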
{
"cell_type": "markdown",
"metadata": {},
"source": [
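"Mathematically, if you have a vector-valued function $\\vec{y}=f(\\vec{x})$,\n",
"then the gradient of $\\vec{y}$ with respect to $\\vec{x}$\n",
"is a Jacobian matrix, written here with one row per input $x_i$ (the transpose\n",
"of the other common convention, in which row $i$ holds the derivatives of $y_i$):\n",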
"\n",
"\\begin{align}J=\\left(\\begin{array}{ccc}\n",
" \\frac{\\partial y_{1}}{\\partial x_{1}} & \\cdots & \\frac{\\partial y_{m}}{\\partial x_{1}}\\\\\n",
" \\vdots & \\ddots & \\vdots\\\\\n",
" \\frac{\\partial y_{1}}{\\partial x_{n}} & \\cdots & \\frac{\\partial y_{m}}{\\partial x_{n}}\n",
" \\end{array}\\right)\\end{align}\n",
"\n",
"Generally speaking, ``torch.autograd`` is an engine for computing\n",
"Jacobian-vector products. That is, given any vector\n",
"$v=\\left(\\begin{array}{cccc} v_{1} & v_{2} & \\cdots & v_{m}\\end{array}\\right)^{T}$,\n",
"compute the product $J\\cdot v$. If $v$ happens to be\n",
"the gradient of a scalar function $l=g\\left(\\vec{y}\\right)$,\n",
"that is,\n",
"$v=\\left(\\begin{array}{ccc}\\frac{\\partial l}{\\partial y_{1}} & \\cdots & \\frac{\\partial l}{\\partial y_{m}}\\end{array}\\right)^{T}$,\n",
"then by the chain rule, the Jacobian-vector product would be the\n",
"gradient of $l$ with respect to $\\vec{x}$:\n",
"\n",
"\\begin{align}J\\cdot v=\\left(\\begin{array}{ccc}\n",
" \\frac{\\partial y_{1}}{\\partial x_{1}} & \\cdots & \\frac{\\partial y_{m}}{\\partial x_{1}}\\\\\n",
" \\vdots & \\ddots & \\vdots\\\\\n",
" \\frac{\\partial y_{1}}{\\partial x_{n}} & \\cdots & \\frac{\\partial y_{m}}{\\partial x_{n}}\n",
" \\end{array}\\right)\\left(\\begin{array}{c}\n",
" \\frac{\\partial l}{\\partial y_{1}}\\\\\n",
" \\vdots\\\\\n",
" \\frac{\\partial l}{\\partial y_{m}}\n",
" \\end{array}\\right)=\\left(\\begin{array}{c}\n",
" \\frac{\\partial l}{\\partial x_{1}}\\\\\n",
" \\vdots\\\\\n",
" \\frac{\\partial l}{\\partial x_{n}}\n",
" \\end{array}\\right)\\end{align}\n",
"\n",
"This property of the Jacobian-vector product makes it very\n",
"convenient to feed external gradients into a model that has\n",
"non-scalar output.\n",
"\n"
]
},
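{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small numerical check of the formula above (a sketch, not from the original tutorial; the matrix ``A`` and the vector ``v`` are arbitrary): for the linear map $\\vec{y} = A\\vec{x}$, the matrix $J$ written above, with one row per input, equals $A^{T}$, so ``y.backward(v)`` should leave $A^{T}v$ in ``x.grad``.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"torch.manual_seed(0)\n",
"A = torch.randn(2, 3)                    # y = A x maps R^3 to R^2\n",
"x = torch.randn(3, requires_grad=True)\n",
"y = A @ x                                # non-scalar output of shape (2,)\n",
"\n",
"v = torch.tensor([1.0, -2.0])            # stands in for dl/dy\n",
"y.backward(v)\n",
"\n",
"# Each entry of x.grad is sum_j dy_j/dx_i * v_j, which is A^T v here\n",
"print(x.grad)\n",
"print(A.t() @ v)\n",
"print(torch.allclose(x.grad, A.t() @ v))  # True"
]
},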
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's take a look at an example of Jacobian-vector product:\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([ 0.8672, -0.8962, 1.2666], grad_fn=<MulBackward0>)\n",
"None\n",
"tensor(1.7775)\n",
"tensor([ 1.7344, -1.7924, 2.5333], grad_fn=<MulBackward0>)\n",
"None\n",
"tensor(3.5550)\n",
"tensor([ 3.4689, -3.5847, 5.0666], grad_fn=<MulBackward0>)\n",
"None\n",
"tensor(7.1101)\n",
"tensor([ 6.9377, -7.1695, 10.1331], grad_fn=<MulBackward0>)\n",
"None\n",
"tensor(14.2202)\n",
"tensor([ 13.8754, -14.3389, 20.2663], grad_fn=<MulBackward0>)\n",
"None\n",
"tensor(28.4403)\n",
"tensor([ 27.7508, -28.6778, 40.5325], grad_fn=<MulBackward0>)\n",
"None\n",
"tensor(56.8807)\n",
"tensor([ 55.5016, -57.3557, 81.0650], grad_fn=<MulBackward0>)\n",
"None\n",
"tensor(113.7614)\n",
"tensor([ 111.0033, -114.7114, 162.1301], grad_fn=<MulBackward0>)\n",
"None\n",
"tensor(227.5227)\n",
"tensor([ 222.0066, -229.4228, 324.2601], grad_fn=<MulBackward0>)\n",
"None\n",
"tensor(455.0455)\n",
"tensor([ 444.0132, -458.8456, 648.5203], grad_fn=<MulBackward0>)\n",
"None\n",
"tensor(910.0909)\n",
"tensor([ 888.0263, -917.6912, 1297.0405], grad_fn=<MulBackward0>)\n"
]
}
],
"source": [
"x = torch.randn(3, requires_grad=True)\n",
"\n",
"y = x * 2\n",
"while y.data.norm() < 1000:\n",
" print(y)\n",
" print(y.grad)\n",
" print(y.data.norm())\n",
" y = y * 2\n",
"\n",
"print(y)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now in this case ``y`` is no longer a scalar. ``torch.autograd``\n",
"does not compute the full Jacobian directly, but if we only\n",
"want the Jacobian-vector product, we can simply pass the vector to\n",
"``backward`` as an argument:\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([2.0480e+02, 2.0480e+03, 2.0480e-01])\n"
]
}
],
"source": [
"v = torch.tensor([0.1, 1.0, 0.0001], dtype=torch.float)\n",
"\n",
"y.backward(v)\n",
"\n",
"print(x.grad)"
]
},
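{
"cell_type": "markdown",
"metadata": {},
"source": [
"Another way to read the ``gradient`` argument (a sketch, not from the original tutorial): passing a vector of ones gives the same result as calling ``backward()`` on the sum of ``y``. Because ``.grad`` accumulates across ``backward`` calls, it is reset with ``zero_()`` between the two runs, and the graph is rebuilt because the first ``backward`` call frees it.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"x = torch.randn(3, requires_grad=True)\n",
"\n",
"y = x * 2048                          # a fixed linear map, like the repeated doubling above\n",
"y.backward(torch.ones_like(y))        # v = (1, 1, 1)\n",
"g_ones = x.grad.clone()\n",
"\n",
"x.grad.zero_()                        # .grad accumulates, so reset it first\n",
"y = x * 2048                          # rebuild the graph (the first backward freed it)\n",
"y.sum().backward()                    # gradient of the scalar sum(y)\n",
"\n",
"print(g_ones)                         # 2048 in every entry\n",
"print(torch.equal(g_ones, x.grad))    # True"
]
},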
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can also stop autograd from tracking history on Tensors\n",
"with ``requires_grad=True`` by wrapping the code block in\n",
"``with torch.no_grad()``:\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(x.requires_grad)\n",
"print((x ** 2).requires_grad)\n",
"\n",
"with torch.no_grad():\n",
"    print((x ** 2).requires_grad)"
]
},
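{
"cell_type": "markdown",
"metadata": {},
"source": [
"A typical use of ``torch.no_grad()`` (a sketch, not from the original tutorial; the tensor ``w`` and the step size ``0.1`` are illustrative): a manual gradient-descent update of a tensor that requires gradients. The update is wrapped in ``no_grad`` so that it is not itself recorded, and ``.grad`` is zeroed afterwards because gradients accumulate.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"w = torch.tensor([1.0, 2.0], requires_grad=True)\n",
"loss = (w ** 2).sum()\n",
"loss.backward()               # w.grad = 2 * w = [2., 4.]\n",
"\n",
"with torch.no_grad():         # do not record the update in the graph\n",
"    w -= 0.1 * w.grad         # in-place update of a leaf tensor is allowed here\n",
"w.grad.zero_()                # clear the accumulated gradient for the next step\n",
"\n",
"print(w)                      # approximately [0.8, 1.6], still requires_grad=True\n",
"print(w.grad)                 # [0., 0.]"
]
},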
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"import autograd.numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from autograd import grad\n",
"\n",
"def fun(x):\n",
" return np.sin(x)\n",
"\n",
"d_fun = grad(fun) # First derivative\n",
"dd_fun = grad(d_fun) # Second derivative\n",
"\n",
"x = np.linspace(-10, 10, 100)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# In Python 3, map() returns a lazy iterator, which matplotlib rejects\n",
"# ('matplotlib does not support generators as input'), so build lists first.\n",
"aa = list(map(fun, x))     # sin(x)\n",
"bb = list(map(d_fun, x))   # first derivative, cos(x)\n",
"cc = list(map(dd_fun, x))  # second derivative, -sin(x)\n",
"plt.plot(x, aa, x, bb, x, cc)"
]
},
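{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cells above use the separate HIPS ``autograd`` package rather than ``torch.autograd``. As a sketch (not from the original notebook), ``autograd.elementwise_grad`` differentiates a function that acts elementwise over a whole array in one call, which avoids looping with ``map``; the results can be compared with the known derivatives $\\cos x$ and $-\\sin x$.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import autograd.numpy as np\n",
"from autograd import grad, elementwise_grad\n",
"\n",
"def fun(x):\n",
"    return np.sin(x)\n",
"\n",
"x = np.linspace(-10, 10, 100)\n",
"d_vec = elementwise_grad(fun)       # first derivative over the whole array\n",
"dd_vec = elementwise_grad(d_vec)    # second derivative\n",
"\n",
"print(np.allclose(d_vec(x), np.cos(x)))         # True\n",
"print(np.allclose(dd_vec(x), -np.sin(x)))       # True\n",
"print(np.isclose(grad(fun)(0.5), np.cos(0.5)))  # True: grad works on scalars"
]
},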
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"35632530.0\n",
"34166744.0\n",
"37059890.0\n",
"37217424.0\n",
"30241078.0\n",
"19255156.0\n",
"9918273.0\n",
"4791853.0\n",
"2483867.2\n",
"1507002.4\n",
"1056859.8\n",
"816242.9\n",
"664323.9\n",
"555862.06\n",
"472214.22\n",
"404901.0\n",
"349502.06\n",
"303349.28\n",
"264535.3\n",
"231685.06\n",
"203718.7\n",
"179735.56\n",
"159073.28\n",
"141187.3\n",
"125644.38\n",
"112106.195\n",
"100336.96\n",
"89999.05\n",
"80901.375\n",
"72865.91\n",
"65750.805\n",
"59429.62\n",
"53805.305\n",
"48792.195\n",
"44316.61\n",
"40311.895\n",
"36718.723\n",
"33505.617\n",
"30616.807\n",
"28014.102\n",
"25665.594\n",
"23543.066\n",
"21622.166\n",
"19877.871\n",
"18292.67\n",
"16851.28\n",
"15538.994\n",
"14341.082\n",
"13246.801\n",
"12245.09\n",
"11327.864\n",
"10488.254\n",
"9718.745\n",
"9011.977\n",
"8362.131\n",
"7764.3496\n",
"7214.0176\n",
"6706.7764\n",
"6238.8716\n",
"5807.46\n",
"5408.708\n",
"5040.1167\n",
"4698.984\n",
"4383.2764\n",
"4090.8525\n",
"3819.8293\n",
"3568.6116\n",
"3335.4834\n",
"3118.8237\n",
"2917.6167\n",
"2730.524\n",
"2556.3657\n",
"2394.3887\n",
"2243.5754\n",
"2103.1663\n",
"1972.2114\n",
"1850.0674\n",
"1736.091\n",
"1629.8108\n",
"1530.5598\n",
"1437.8413\n",
"1351.1545\n",
"1270.1364\n",
"1194.3129\n",
"1123.4174\n",
"1056.9854\n",
"994.8311\n",
"936.62683\n",
"882.04333\n",
"830.8779\n",
"782.8898\n",
"737.8443\n",
"695.5705\n",
"655.8908\n",
"618.63513\n",
"583.61523\n",
"550.7021\n",
"519.7772\n",
"490.68488\n",
"463.32205\n",
"437.5882\n",
"413.36047\n",
"390.552\n",
"369.07935\n",
"348.85327\n",
"329.79077\n",
"311.8376\n",
"294.9079\n",
"278.95068\n",
"263.90088\n",
"249.71323\n",
"236.31973\n",
"223.6864\n",
"211.75989\n",
"200.50742\n",
"189.88086\n",
"179.84396\n",
"170.36368\n",
"161.40729\n",
"152.94498\n",
"144.94736\n",
"137.38744\n",
"130.23933\n",
"123.47879\n",
"117.085495\n",
"111.03822\n",
"105.31457\n",
"99.89987\n",
"94.77915\n",
"89.92987\n",
"85.338295\n",
"80.989685\n",
"76.87491\n",
"72.97392\n",
"69.28178\n",
"65.783005\n",
"62.469368\n",
"59.32553\n",
"56.347363\n",
"53.52438\n",
"50.850094\n",
"48.311546\n",
"45.906017\n",
"43.623707\n",
"41.4588\n",
"39.40451\n",
"37.45754\n",
"35.60898\n",
"33.854774\n",
"32.188683\n",
"30.60874\n",
"29.10807\n",
"27.684772\n",
"26.331486\n",
"25.047714\n",
"23.827518\n",
"22.669542\n",
"21.569008\n",
"20.523762\n",
"19.530045\n",
"18.586739\n",
"17.689764\n",
"16.83685\n",
"16.026623\n",
"15.257511\n",
"14.525566\n",
"13.829012\n",
"13.16762\n",
"12.538312\n",
"11.940166\n",
"11.37129\n",
"10.829909\n",
"10.315443\n",
"9.825314\n",
"9.359444\n",
"8.916112\n",
"8.494016\n",
"8.092807\n",
"7.710865\n",
"7.3474393\n",
"7.001525\n",
"6.6722794\n",
"6.3586893\n",
"6.060279\n",
"5.7763457\n",
"5.5055294\n",
"5.2478466\n",
"5.00281\n",
"4.768996\n",
"4.546772\n",
"4.3347726\n",
"4.133177\n",
"3.9408875\n",
"3.757871\n",
"3.5835638\n",
"3.4171364\n",
"3.2588482\n",
"3.107923\n",
"2.9642096\n",
"2.8272033\n",
"2.6966646\n",
"2.5724187\n",
"2.4538507\n",
"2.3407292\n",
"2.232832\n",
"2.1302705\n",
"2.032448\n",
"1.9390846\n",
"1.8500855\n",
"1.7653635\n",
"1.6844192\n",
"1.607228\n",
"1.5337617\n",
"1.4637802\n",
"1.3968883\n",
"1.3329959\n",
"1.2721562\n",
"1.2141538\n",
"1.1588588\n",
"1.106006\n",
"1.0557623\n",
"1.0077819\n",
"0.9619578\n",
"0.9182838\n",
"0.8765352\n",
"0.8368921\n",
"0.79887617\n",
"0.76260626\n",
"0.7281928\n",
"0.69511694\n",
"0.66376936\n",
"0.6337569\n",
"0.60503656\n",
"0.5776768\n",
"0.55165267\n",
"0.5267988\n",
"0.5030663\n",
"0.48037484\n",
"0.45875597\n",
"0.43814915\n",
"0.41834652\n",
"0.39955407\n",
"0.3816045\n",
"0.3644629\n",
"0.34806794\n",
"0.3324594\n",
"0.31749597\n",
"0.3032619\n",
"0.2896499\n",
"0.2766657\n",
"0.26426187\n",
"0.2524106\n",
"0.24114057\n",
"0.23031105\n",
"0.22004148\n",
"0.21019603\n",
"0.20078641\n",
"0.19186345\n",
"0.1832726\n",
"0.17512226\n",
"0.1672943\n",
"0.15982454\n",
"0.15268345\n",
"0.14587705\n",
"0.13937628\n",
"0.13317493\n",
"0.12723255\n",
"0.12157785\n",
"0.116137445\n",
"0.11103091\n",
"0.10608796\n",
"0.10134934\n",
"0.096845716\n",
"0.09255558\n",
"0.08844006\n",
"0.08449518\n",
"0.080758534\n",
"0.077191174\n",
"0.07376273\n",
"0.07049048\n",
"0.06736806\n",
"0.06436903\n",
"0.061544143\n",
"0.058808126\n",
"0.056199703\n",
"0.053711433\n",
"0.05134546\n",
"0.049073946\n",
"0.046905868\n",
"0.04483395\n",
"0.042870782\n",
"0.040966555\n",
"0.0391574\n",
"0.037439447\n",
"0.035787366\n",
"0.034214903\n",
"0.032708243\n",
"0.03127743\n",
"0.02990238\n",
"0.02859331\n",
"0.027315525\n",
"0.026126482\n",
"0.024998123\n",
"0.023912752\n",
"0.02285477\n",
"0.021850456\n",
"0.020899475\n",
"0.01998027\n",
"0.01911814\n",
"0.018292094\n",
"0.017489234\n",
"0.016738568\n",
"0.016006611\n",
"0.015314603\n",
"0.014655284\n",
"0.014027941\n",
"0.013423682\n",
"0.0128369415\n",
"0.012293164\n",
"0.011760444\n",
"0.01126135\n",
"0.01076903\n",
"0.010314537\n",
"0.009871031\n",
"0.009450923\n",
"0.0090458365\n",
"0.008660833\n",
"0.008291104\n",
"0.007941701\n",
"0.007604987\n",
"0.0072935773\n",
"0.0069828364\n",
"0.0066894656\n",
"0.006408456\n",
"0.0061393823\n",
"0.0058761695\n",
"0.005631969\n",
"0.0053990455\n",
"0.005176464\n",
"0.0049613034\n",
"0.004754216\n",
"0.0045597567\n",
"0.00437132\n",
"0.004196186\n",
"0.004025424\n",
"0.0038602178\n",
"0.0036979723\n",
"0.003553073\n",
"0.00340943\n",
"0.0032753963\n",
"0.0031435248\n",
"0.0030196323\n",
"0.0029005534\n",
"0.0027837413\n",
"0.0026735407\n",
"0.0025679425\n",
"0.0024658751\n",
"0.0023711955\n",
"0.0022783037\n",
"0.0021896916\n",
"0.0021063932\n",
"0.0020265665\n",
"0.0019517792\n",
"0.0018736036\n",
"0.0018061891\n",
"0.0017405101\n",
"0.0016761564\n",
"0.001612894\n",
"0.0015550405\n",
"0.0014976561\n",
"0.0014448334\n",
"0.0013956067\n",
"0.0013455102\n",
"0.0012995332\n",
"0.0012531722\n",
"0.001210189\n",
"0.0011679351\n",
"0.0011284054\n",
"0.001089638\n",
"0.0010528808\n",
"0.0010150422\n",
"0.0009816474\n",
"0.00094942603\n",
"0.000918931\n",
"0.0008890177\n",
"0.00085978064\n",
"0.0008319805\n",
"0.0008041366\n",
"0.0007762797\n",
"0.0007538185\n",
"0.00073180144\n",
"0.00070841244\n",
"0.00068493537\n",
"0.00066376565\n",
"0.0006444861\n",
"0.0006249069\n",
"0.0006056711\n",
"0.00058749976\n",
"0.00056944974\n",
"0.0005524007\n",
"0.0005361368\n",
"0.0005208185\n",
"0.00050469564\n",
"0.0004894966\n",
"0.0004758743\n",
"0.00046147133\n",
"0.00044808272\n",
"0.00043579712\n",
"0.00042489968\n",
"0.00041232834\n",
"0.00040144968\n",
"0.00039016161\n",
"0.00037896418\n",
"0.0003684666\n",
"0.00035823602\n",
"0.00034837553\n",
"0.00033894175\n",
"0.0003298471\n",
"0.00032139392\n",
"0.00031297834\n",
"0.0003046702\n",
"0.00029719202\n",
"0.00029006044\n",
"0.00028194688\n",
"0.00027473248\n",
"0.00026784418\n",
"0.00026087818\n",
"0.00025449175\n",
"0.00024836036\n",
"0.00024265141\n",
"0.00023689025\n",
"0.00023100003\n",
"0.00022525838\n",
"0.00022012439\n",
"0.00021410642\n",
"0.00021031541\n",
"0.00020558643\n",
"0.00020068254\n",
"0.00019690076\n",
"0.00019188935\n",
"0.00018755638\n",
"0.00018369195\n",
"0.0001789029\n",
"0.0001755635\n",
"0.0001713126\n",
"0.00016828944\n",
"0.00016430137\n",
"0.00016102222\n",
"0.0001580959\n",
"0.00015517048\n",
"0.0001515107\n",
"0.00014860292\n",
"0.00014533353\n",
"0.0001421867\n",
"0.00013938035\n",
"0.00013639471\n",
"0.00013359617\n",
"0.00013160022\n",
"0.000128804\n",
"0.00012634753\n",
"0.00012422545\n",
"0.00012169797\n",
"0.00011960982\n",
"0.00011716273\n",
"0.00011489865\n",
"0.00011276081\n",
"0.000110744135\n",
"0.00010880789\n",
"0.000107240965\n",
"0.00010532818\n",
"0.00010319804\n",
"0.00010140792\n",
"9.983361e-05\n",
"9.834589e-05\n",
"9.67456e-05\n",
"9.513831e-05\n",
"9.343762e-05\n",
"9.201749e-05\n",
"9.0557965e-05\n",
"8.896999e-05\n",
"8.72736e-05\n",
"8.5892316e-05\n",
"8.469775e-05\n",
"8.290162e-05\n",
"8.2124745e-05\n",
"8.0443606e-05\n",
"7.909554e-05\n",
"7.781605e-05\n",
"7.646244e-05\n",
"7.543016e-05\n",
"7.415263e-05\n",
"7.3085306e-05\n",
"7.193201e-05\n",
"7.052146e-05\n",
"6.972016e-05\n",
"6.8984824e-05\n",
"6.792543e-05\n",
"6.698833e-05\n",
"6.534216e-05\n",
"6.4625354e-05\n"
]
}
],
"source": [
"import tensorflow as tf\n",
"import numpy as np\n",
"\n",
"\"\"\"\n",
"A fully-connected ReLU network with one hidden layer and no biases, trained to\n",
"predict y from x by minimizing squared Euclidean distance.\n",
"\n",
"This implementation uses basic TensorFlow operations to set up a computational\n",
"graph, then executes the graph many times to actually train the network.\n",
"\n",
"One of the main differences between TensorFlow 1.x and PyTorch is that\n",
"TensorFlow 1.x uses static computational graphs, while PyTorch uses dynamic\n",
"computational graphs.\n",
"\n",
"In TensorFlow we first set up the computational graph, then execute the same\n",
"graph many times.\n",
"\"\"\"\n",
"\n",
"# First we set up the computational graph:\n",
"\n",
"# N is batch size; D_in is input dimension;\n",
"# H is hidden dimension; D_out is output dimension.\n",
"N, D_in, H, D_out = 64, 1000, 100, 10\n",
"\n",
"# Create placeholders for the input and target data; these will be filled\n",
"# with real data when we execute the graph.\n",
"x = tf.placeholder(tf.float32, shape=(None, D_in))\n",
"y = tf.placeholder(tf.float32, shape=(None, D_out))\n",
"\n",
"# Create Variables for the weights and initialize them with random data.\n",
"# A TensorFlow Variable persists its value across executions of the graph.\n",
"w1 = tf.Variable(tf.random_normal((D_in, H)))\n",
"w2 = tf.Variable(tf.random_normal((H, D_out)))\n",
"\n",
"# Forward pass: Compute the predicted y using operations on TensorFlow Tensors.\n",
"# Note that this code does not actually perform any numeric operations; it\n",
"# merely sets up the computational graph that we will later execute.\n",
"h = tf.matmul(x, w1)\n",
"h_relu = tf.maximum(h, tf.zeros(1))\n",
"y_pred = tf.matmul(h_relu, w2)\n",
"\n",
"# Compute loss using operations on TensorFlow Tensors\n",
"loss = tf.reduce_sum((y - y_pred) ** 2.0)\n",
"\n",
"# Compute gradient of the loss with respect to w1 and w2.\n",
"grad_w1, grad_w2 = tf.gradients(loss, [w1, w2])\n",
"\n",
"# Update the weights using gradient descent. To actually update the weights\n",
"# we need to evaluate new_w1 and new_w2 when executing the graph. Note that\n",
"# in TensorFlow the act of updating the value of the weights is part of\n",
"# the computational graph; in PyTorch this happens outside the computational\n",
"# graph.\n",
"learning_rate = 1e-6\n",
"new_w1 = w1.assign(w1 - learning_rate * grad_w1)\n",
"new_w2 = w2.assign(w2 - learning_rate * grad_w2)\n",
"\n",
"# Now we have built our computational graph, so we enter a TensorFlow session to\n",
"# actually execute the graph.\n",
"with tf.Session() as sess:\n",
"    # Run the graph once to initialize the Variables w1 and w2.\n",
"    sess.run(tf.global_variables_initializer())\n",
"\n",
"    # Create numpy arrays holding the actual data for the inputs x and targets y\n",
"    x_value = np.random.randn(N, D_in)\n",
"    y_value = np.random.randn(N, D_out)\n",
"    for _ in range(500):\n",
"        # Execute the graph many times. Each time it executes we want to bind\n",
"        # x_value to x and y_value to y, specified with the feed_dict argument.\n",
"        # Each time we execute the graph we want to compute the values for loss,\n",
"        # new_w1, and new_w2; the values of these Tensors are returned as numpy\n",
"        # arrays.\n",
"        loss_value, _, _ = sess.run([loss, new_w1, new_w2],\n",
"                                    feed_dict={x: x_value, y: y_value})\n",
"        print(loss_value)"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 24688302.0\n",
"1 17984838.0\n",
"2 14542883.0\n",
"3 12217404.0\n",
"4 10226319.0\n",
"5 8339861.5\n",
"6 6578238.0\n",
"7 5038621.5\n",
"8 3776320.5\n",
"9 2804629.5\n",
"10 2084264.875\n",
"11 1565664.25\n",
"12 1194818.375\n",
"13 930355.5\n",
"14 739639.625\n",
"15 600507.625\n",
"16 496891.84375\n",
"17 418130.15625\n",
"18 356857.78125\n",
"19 308160.625\n",
"20 268691.0\n",
"21 236098.40625\n",
"22 208795.15625\n",
"23 185622.296875\n",
"24 165742.59375\n",
"25 148533.828125\n",
"26 133525.359375\n",
"27 120362.0390625\n",
"28 108754.03125\n",
"29 98481.6328125\n",
"30 89344.921875\n",
"31 81193.390625\n",
"32 73899.1328125\n",
"33 67357.7109375\n",
"34 61493.2421875\n",
"35 56223.859375\n",
"36 51466.671875\n",
"37 47166.6875\n",
"38 43270.78515625\n",
"39 39742.6015625\n",
"40 36539.7421875\n",
"41 33624.94921875\n",
"42 30971.720703125\n",
"43 28551.142578125\n",
"44 26338.978515625\n",
"45 24316.609375\n",
"46 22465.20703125\n",
"47 20767.7890625\n",
"48 19212.06640625\n",
"49 17785.044921875\n",
"50 16473.28515625\n",
"51 15266.7978515625\n",
"52 14156.47265625\n",
"53 13133.5947265625\n",
"54 12191.138671875\n",
"55 11321.9384765625\n",
"56 10519.5732421875\n",
"57 9778.431640625\n",
"58 9093.568359375\n",
"59 8460.59375\n",
"60 7875.666015625\n",
"61 7334.12451171875\n",
"62 6832.69140625\n",
"63 6368.00341796875\n",
"64 5937.0830078125\n",
"65 5537.64599609375\n",
"66 5166.98388671875\n",
"67 4822.73388671875\n",
"68 4502.95458984375\n",
"69 4205.8037109375\n",
"70 3929.57177734375\n",
"71 3672.575927734375\n",
"72 3433.62744140625\n",
"73 3211.204345703125\n",
"74 3004.191162109375\n",
"75 2811.431396484375\n",
"76 2631.81103515625\n",
"77 2464.303955078125\n",
"78 2308.0771484375\n",
"79 2162.40673828125\n",
"80 2026.4805908203125\n",
"81 1899.532958984375\n",
"82 1780.9422607421875\n",
"83 1670.201904296875\n",
"84 1566.7562255859375\n",
"85 1470.1854248046875\n",
"86 1379.8446044921875\n",
"87 1295.3297119140625\n",
"88 1216.2265625\n",
"89 1142.2291259765625\n",
"90 1072.973876953125\n",
"91 1008.1453247070312\n",
"92 947.4297485351562\n",
"93 890.5519409179688\n",
"94 837.2857666015625\n",
"95 787.373291015625\n",
"96 740.5667724609375\n",
"97 696.6860961914062\n",
"98 655.5277709960938\n",
"99 616.909912109375\n",
"100 580.7116088867188\n",
"101 546.7200927734375\n",
"102 514.8139038085938\n",
"103 484.86077880859375\n",
"104 456.7369384765625\n",
"105 430.322509765625\n",
"106 405.5072937011719\n",
"107 382.189697265625\n",
"108 360.2792053222656\n",
"109 339.68060302734375\n",
"110 320.32708740234375\n",
"111 302.1233215332031\n",
"112 285.00146484375\n",
"113 268.8934631347656\n",
"114 253.73870849609375\n",
"115 239.47573852539062\n",
"116 226.0470733642578\n",
"117 213.40296936035156\n",
"118 201.50918579101562\n",
"119 190.30291748046875\n",
"120 179.75070190429688\n",
"121 169.80784606933594\n",
"122 160.43560791015625\n",
"123 151.60121154785156\n",
"124 143.27548217773438\n",
"125 135.42539978027344\n",
"126 128.02378845214844\n",
"127 121.04232025146484\n",
"128 114.46221160888672\n",
"129 108.25297546386719\n",
"130 102.39306640625\n",
"131 96.86213684082031\n",
"132 91.64041900634766\n",
"133 86.71172332763672\n",
"134 82.06002807617188\n",
"135 77.66609191894531\n",
"136 73.51763916015625\n",
"137 69.59766387939453\n",
"138 65.89807891845703\n",
"139 62.40123748779297\n",
"140 59.096771240234375\n",
"141 55.97258377075195\n",
"142 53.02041244506836\n",
"143 50.22927474975586\n",
"144 47.58990478515625\n",
"145 45.09389114379883\n",
"146 42.73532485961914\n",
"147 40.50306701660156\n",
"148 38.39479064941406\n",
"149 36.39691925048828\n",
"150 34.50675964355469\n",
"151 32.71832275390625\n",
"152 31.026700973510742\n",
"153 29.425796508789062\n",
"154 27.910646438598633\n",
"155 26.475162506103516\n",
"156 25.116464614868164\n",
"157 23.82939910888672\n",
"158 22.611515045166016\n",
"159 21.45804786682129\n",
"160 20.36470603942871\n",
"161 19.32872200012207\n",
"162 18.347423553466797\n",
"163 17.417695999145508\n",
"164 16.536418914794922\n",
"165 15.701050758361816\n",
"166 14.908943176269531\n",
"167 14.15771484375\n",
"168 13.446287155151367\n",
"169 12.771245956420898\n",
"170 12.131441116333008\n",
"171 11.524349212646484\n",
"172 10.947635650634766\n",
"173 10.401337623596191\n",
"174 9.882925987243652\n",
"175 9.391077995300293\n",
"176 8.924080848693848\n",
"177 8.481295585632324\n",
"178 8.061117172241211\n",
"179 7.661981105804443\n",
"180 7.283514499664307\n",
"181 6.923940181732178\n",
"182 6.582524299621582\n",
"183 6.258670806884766\n",
"184 5.950675964355469\n",
"185 5.658329010009766\n",
"186 5.380758285522461\n",
"187 5.117119789123535\n",
"188 4.866908073425293\n",
"189 4.628957271575928\n",
"190 4.402910232543945\n",
"191 4.18823766708374\n",
"192 3.9842240810394287\n",
"193 3.790332555770874\n",
"194 3.606066942214966\n",
"195 3.4311227798461914\n",
"196 3.2645761966705322\n",
"197 3.106358289718628\n",
"198 2.956329584121704\n",
"199 2.8134841918945312\n",
"200 2.6773831844329834\n",
"201 2.548379421234131\n",
"202 2.4255940914154053\n",
"203 2.308701515197754\n",
"204 2.1976122856140137\n",
"205 2.092134475708008\n",
"206 1.9916737079620361\n",
"207 1.8961156606674194\n",
"208 1.8052055835723877\n",
"209 1.718900442123413\n",
"210 1.6366509199142456\n",
"211 1.5583386421203613\n",
"212 1.4838244915008545\n",
"213 1.4132423400878906\n",
"214 1.3459393978118896\n",
"215 1.2817420959472656\n",
"216 1.2207046747207642\n",
"217 1.1626098155975342\n",
"218 1.107551097869873\n",
"219 1.054832100868225\n",
"220 1.0048843622207642\n",
"221 0.9573338627815247\n",
"222 0.9119561314582825\n",
"223 0.8686810731887817\n",
"224 0.8277258276939392\n",
"225 0.788608729839325\n",
"226 0.7513725757598877\n",
"227 0.7158702611923218\n",
"228 0.6821508407592773\n",
"229 0.6500615477561951\n",
"230 0.619461715221405\n",
"231 0.5902796983718872\n",
"232 0.5625408291816711\n",
"233 0.5360639691352844\n",
"234 0.5109221339225769\n",
"235 0.4869716167449951\n",
"236 0.4642111659049988\n",
"237 0.4423781633377075\n",
"238 0.4216068983078003\n",
"239 0.40193697810173035\n",
"240 0.38310009241104126\n",
"241 0.3652460277080536\n",
"242 0.348175972700119\n",
"243 0.33189767599105835\n",
"244 0.3163726031780243\n",
"245 0.30159080028533936\n",
"246 0.28754404187202454\n",
"247 0.27415674924850464\n",
"248 0.26136279106140137\n",
"249 0.2492113560438156\n",
"250 0.23764055967330933\n",
"251 0.22657275199890137\n",
"252 0.21601782739162445\n",
"253 0.20601952075958252\n",
"254 0.1964670568704605\n",
"255 0.187359020113945\n",
"256 0.17864982783794403\n",
"257 0.1704040914773941\n",
"258 0.16252264380455017\n",
"259 0.15498709678649902\n",
"260 0.14781248569488525\n",
"261 0.140987366437912\n",
"262 0.1344553679227829\n",
"263 0.12824320793151855\n",
"264 0.12234124541282654\n",
"265 0.1166907250881195\n",
"266 0.1113230511546135\n",
"267 0.10620296001434326\n",
"268 0.10131487250328064\n",
"269 0.09665443003177643\n",
"270 0.09221497178077698\n",
"271 0.08798132091760635\n",
"272 0.08392312377691269\n",
"273 0.08008478581905365\n",
"274 0.07640678435564041\n",
"275 0.07290326058864594\n",
"276 0.06956953555345535\n",
"277 0.0663725733757019\n",
"278 0.06333693116903305\n",
"279 0.06042061746120453\n",
"280 0.0576675646007061\n",
"281 0.055027298629283905\n",
"282 0.05251595750451088\n",
"283 0.05009906366467476\n",
"284 0.047803375869989395\n",
"285 0.04562835767865181\n",
"286 0.043548718094825745\n",
"287 0.04157324507832527\n",
"288 0.039687953889369965\n",
"289 0.037872858345508575\n",
"290 0.03613417595624924\n",
"291 0.034499578177928925\n",
"292 0.032925188541412354\n",
"293 0.03143469616770744\n",
"294 0.03000658005475998\n",
"295 0.028628332540392876\n",
"296 0.02732955664396286\n",
"297 0.026118163019418716\n",
"298 0.024930033832788467\n",
"299 0.02379368618130684\n",
"300 0.02270844765007496\n",
"301 0.021688120439648628\n",
"302 0.0207170732319355\n",
"303 0.019782159477472305\n",
"304 0.01889665797352791\n",
"305 0.018045848235487938\n",
"306 0.01723134145140648\n",
"307 0.016461752355098724\n",
"308 0.015723150223493576\n",
"309 0.015016935765743256\n",
"310 0.01434311456978321\n",
"311 0.013705356977880001\n",
"312 0.013095956295728683\n",
"313 0.012517692521214485\n",
"314 0.01196172647178173\n",
"315 0.011421596631407738\n",
"316 0.01091529056429863\n",
"317 0.010437062941491604\n",
"318 0.00997946597635746\n",
"319 0.009533328004181385\n",
"320 0.009121168404817581\n",
"321 0.008721508085727692\n",
"322 0.008335432969033718\n",
"323 0.00798106286674738\n",
"324 0.007627126760780811\n",
"325 0.00730485562235117\n",
"326 0.0069808997213840485\n",
"327 0.0066769300028681755\n",
"328 0.006390679627656937\n",
"329 0.006114493124186993\n",
"330 0.00584913045167923\n",
"331 0.005599178373813629\n",
"332 0.005361250136047602\n",
"333 0.005131804384291172\n",
"334 0.004914569202810526\n",
"335 0.004706747829914093\n",
"336 0.004509361926466227\n",
"337 0.004315936006605625\n",
"338 0.004138599615544081\n",
"339 0.003964624833315611\n",
"340 0.0037977350875735283\n",
"341 0.003639361821115017\n",
"342 0.003493860363960266\n",
"343 0.0033509270288050175\n",
"344 0.003214005148038268\n",
"345 0.0030827545560896397\n",
"346 0.0029589012265205383\n",
"347 0.002841042587533593\n",
"348 0.002726183272898197\n",
"349 0.002619938226416707\n",
"350 0.002515277359634638\n",
"351 0.0024151599500328302\n",
"352 0.002318766200914979\n",
"353 0.0022258039098232985\n",
"354 0.002140473108738661\n",
"355 0.0020580259151756763\n",
"356 0.001978270709514618\n",
"357 0.0019029895775020123\n",
"358 0.0018288800492882729\n",
"359 0.0017595937242731452\n",
"360 0.0016915813321247697\n",
"361 0.0016288674669340253\n",
"362 0.0015665694372728467\n",
"363 0.001508314162492752\n",
"364 0.001454052748158574\n",
"365 0.0013997266069054604\n",
"366 0.001349251950159669\n",
"367 0.0012997810263186693\n",
"368 0.0012537735747173429\n",
"369 0.0012090876698493958\n",
"370 0.0011649469379335642\n",
"371 0.0011251724790781736\n",
"372 0.001086117117665708\n",
"373 0.001048664445988834\n",
"374 0.001013571280054748\n",
"375 0.000976726645603776\n",
"376 0.0009436089312657714\n",
"377 0.0009109410457313061\n",
"378 0.0008817656198516488\n",
"379 0.0008514215587638319\n",
"380 0.0008234234410338104\n",
"381 0.0007956388290040195\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"382 0.0007705247262492776\n",
"383 0.0007460417109541595\n",
"384 0.0007212982745841146\n",
"385 0.0006969314999878407\n",
"386 0.0006756981019861996\n",
"387 0.0006541507318615913\n",
"388 0.0006323024863377213\n",
"389 0.0006144408253021538\n",
"390 0.000595973979216069\n",
"391 0.0005770920543000102\n",
"392 0.000559464911930263\n",
"393 0.0005413927137851715\n",
"394 0.0005257160519249737\n",
"395 0.0005113421939313412\n",
"396 0.0004966298583894968\n",
"397 0.000481585506349802\n",
"398 0.0004678669210989028\n",
"399 0.00045517252874560654\n",
"400 0.00044213910587131977\n",
"401 0.00042900926200672984\n",
"402 0.0004169284366071224\n",
"403 0.00040626025293022394\n",
"404 0.0003948296362068504\n",
"405 0.0003834187227766961\n",
"406 0.0003736388753168285\n",
"407 0.00036392814945429564\n",
"408 0.0003543612256180495\n",
"409 0.00034513106220401824\n",
"410 0.0003364191798027605\n",
"411 0.00032711352105252445\n",
"412 0.00031908092205412686\n",
"413 0.0003105267242062837\n",
"414 0.000302838918287307\n",
"415 0.00029511633329093456\n",
"416 0.000287334609311074\n",
"417 0.00028105394449084997\n",
"418 0.00027363712433725595\n",
"419 0.00026700543821789324\n",
"420 0.00026004898245446384\n",
"421 0.00025392131647095084\n",
"422 0.000247285672230646\n",
"423 0.0002419824304524809\n",
"424 0.00023607343609910458\n",
"425 0.0002303303190274164\n",
"426 0.0002251435216749087\n",
"427 0.00021941975865047425\n",
"428 0.00021437769464682788\n",
"429 0.00020955922082066536\n",
"430 0.00020557538664434105\n",
"431 0.00020082987612113357\n",
"432 0.0001965878764167428\n",
"433 0.000191882936633192\n",
"434 0.00018795413780026138\n",
"435 0.00018392673518974334\n",
"436 0.00018066607299260795\n",
"437 0.00017658348951954395\n",
"438 0.00017303903587162495\n",
"439 0.00016955489991232753\n",
"440 0.00016591577150393277\n",
"441 0.00016194916679523885\n",
"442 0.00015838695981074125\n",
"443 0.00015563251508865505\n",
"444 0.00015258743951562792\n",
"445 0.00014952228229958564\n",
"446 0.00014703115448355675\n",
"447 0.00014396110782399774\n",
"448 0.00014084955910220742\n",
"449 0.00013792573008686304\n",
"450 0.00013549154391512275\n",
"451 0.00013308131019584835\n",
"452 0.00013055365707259625\n",
"453 0.00012804055586457253\n",
"454 0.00012535802670754492\n",
"455 0.00012324642739258707\n",
"456 0.00012121061445213854\n",
"457 0.00011869342415593565\n",
"458 0.00011700656614266336\n",
"459 0.00011477694351924583\n",
"460 0.0001128563962993212\n",
"461 0.00011103119322797284\n",
"462 0.00010885439405683428\n",
"463 0.00010673737415345386\n",
"464 0.0001053597079589963\n",
"465 0.0001036376052070409\n",
"466 0.00010176044452236965\n",
"467 9.98697432805784e-05\n",
"468 9.798920655157417e-05\n",
"469 9.643139492254704e-05\n",
"470 9.48067317949608e-05\n",
"471 9.311249596066773e-05\n",
"472 9.161644265986979e-05\n",
"473 9.022918675327674e-05\n",
"474 8.881841495167464e-05\n",
"475 8.758777403272688e-05\n",
"476 8.624640759080648e-05\n",
"477 8.479460666421801e-05\n",
"478 8.340038766618818e-05\n",
"479 8.222086034948006e-05\n",
"480 8.082618296612054e-05\n",
"481 7.959656795719638e-05\n",
"482 7.82867573434487e-05\n",
"483 7.693078077863902e-05\n",
"484 7.604445272590965e-05\n",
"485 7.499106141040102e-05\n",
"486 7.391876715701073e-05\n",
"487 7.290508801816031e-05\n",
"488 7.188640302047133e-05\n",
"489 7.109644502634183e-05\n",
"490 6.973437120905146e-05\n",
"491 6.883328023832291e-05\n",
"492 6.789986218791455e-05\n",
"493 6.697298522340134e-05\n",
"494 6.580616172868758e-05\n",
"495 6.498537550214678e-05\n",
"496 6.426713662222028e-05\n",
"497 6.3318271713797e-05\n",
"498 6.261411908781156e-05\n",
"499 6.193968147272244e-05\n"
]
}
],
"source": [
"import torch\n",
"\n",
"\"\"\"\n",
"A fully-connected ReLU network with one hidden layer and no biases, trained to\n",
"predict y from x by minimizing squared Euclidean distance.\n",
"\n",
"This implementation computes the forward pass using operations on PyTorch\n",
"Tensors, and uses PyTorch autograd to compute gradients.\n",
"\n",
"When we create a PyTorch Tensor with requires_grad=True, operations\n",
"involving that Tensor will not just compute values; they will also build up\n",
"a computational graph in the background, allowing us to easily backpropagate\n",
"through the graph to compute gradients of some downstream (scalar) loss with\n",
"respect to a Tensor. Concretely, if x is a leaf Tensor (one created directly\n",
"by the user) with x.requires_grad == True, then after backpropagation x.grad\n",
"will be another Tensor holding the gradient of the scalar loss with respect\n",
"to x.\n",
"\"\"\"\n",
"\n",
"device = torch.device('cpu')\n",
"# device = torch.device('cuda') # Uncomment this to run on GPU\n",
"\n",
"# N is batch size; D_in is input dimension;\n",
"# H is hidden dimension; D_out is output dimension.\n",
"N, D_in, H, D_out = 64, 1000, 100, 10\n",
"\n",
"# Create random Tensors to hold inputs and outputs\n",
"x = torch.randn(N, D_in, device=device)\n",
"y = torch.randn(N, D_out, device=device)\n",
"\n",
"# Create random Tensors for weights; setting requires_grad=True means that we\n",
"# want to compute gradients for these Tensors during the backward pass.\n",
"w1 = torch.randn(D_in, H, device=device, requires_grad=True)\n",
"w2 = torch.randn(H, D_out, device=device, requires_grad=True)\n",
"\n",
"learning_rate = 1e-6\n",
"for t in range(500):\n",
"    # Forward pass: compute predicted y using operations on Tensors. Since w1 and\n",
"    # w2 have requires_grad=True, operations involving these Tensors will cause\n",
"    # PyTorch to build a computational graph, allowing automatic computation of\n",
"    # gradients. Since we are no longer implementing the backward pass by hand,\n",
"    # we don't need to keep references to intermediate values.\n",
"    y_pred = x.mm(w1).clamp(min=0).mm(w2)\n",
"\n",
"    # Compute and print loss. Loss is a Tensor of shape (), and loss.item()\n",
"    # returns its value as a Python number.\n",
"    loss = (y_pred - y).pow(2).sum()\n",
"    print(t, loss.item())\n",
"\n",
"    # Use autograd to compute the backward pass. This call will compute the\n",
"    # gradient of loss with respect to all Tensors with requires_grad=True.\n",
"    # After this call w1.grad and w2.grad will be Tensors holding the gradient\n",
"    # of the loss with respect to w1 and w2 respectively.\n",
"    loss.backward()\n",
"\n",
"    # Update weights using gradient descent. For this step we just want to mutate\n",
"    # the values of w1 and w2 in-place; we don't want to build up a computational\n",
"    # graph for the update steps, so we use the torch.no_grad() context manager\n",
"    # to prevent PyTorch from building a computational graph for the updates.\n",
"    with torch.no_grad():\n",
"        w1 -= learning_rate * w1.grad\n",
"        w2 -= learning_rate * w2.grad\n",
"\n",
"        # Manually zero the gradients after updating the weights: backward()\n",
"        # accumulates into .grad, so stale gradients would otherwise be added\n",
"        # to the next iteration's gradients.\n",
"        w1.grad.zero_()\n",
"        w2.grad.zero_()"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 30072014.0\n",
"1 21278482.0\n",
"2 15539603.0\n",
"3 11161736.0\n",
"4 7837348.5\n",
"5 5445020.5\n",
"6 3813310.25\n",
"7 2729170.5\n",
"8 2013111.0\n",
"9 1533340.875\n",
"10 1203696.25\n",
"11 969689.0\n",
"12 798104.75\n",
"13 668597.3125\n",
"14 566800.0625\n",
"15 485646.75\n",
"16 419749.28125\n",
"17 365319.9375\n",
"18 319870.09375\n",
"19 281881.78125\n",
"20 249538.796875\n",
"21 221924.65625\n",
"22 198114.265625\n",
"23 177447.484375\n",
"24 159425.25\n",
"25 143660.0625\n",
"26 129804.015625\n",
"27 117571.8671875\n",
"28 106744.546875\n",
"29 97149.6953125\n",
"30 88597.2265625\n",
"31 80957.7578125\n",
"32 74114.96875\n",
"33 67969.6796875\n",
"34 62439.3125\n",
"35 57455.203125\n",
"36 52949.84375\n",
"37 48873.2109375\n",
"38 45176.828125\n",
"39 41816.546875\n",
"40 38755.08203125\n",
"41 35961.484375\n",
"42 33409.0546875\n",
"43 31072.8203125\n",
"44 28931.333984375\n",
"45 26964.8671875\n",
"46 25157.8203125\n",
"47 23494.16796875\n",
"48 21961.365234375\n",
"49 20545.81640625\n",
"50 19238.94140625\n",
"51 18030.287109375\n",
"52 16911.6015625\n",
"53 15874.5244140625\n",
"54 14912.42578125\n",
"55 14018.6552734375\n",
"56 13188.0126953125\n",
"57 12415.0078125\n",
"58 11694.814453125\n",
"59 11023.69140625\n",
"60 10397.52734375\n",
"61 9812.88671875\n",
"62 9266.6748046875\n",
"63 8755.869140625\n",
"64 8277.8271484375\n",
"65 7830.02880859375\n",
"66 7410.328125\n",
"67 7016.7724609375\n",
"68 6647.421875\n",
"69 6300.5244140625\n",
"70 5974.5927734375\n",
"71 5668.1396484375\n",
"72 5379.82568359375\n",
"73 5108.40966796875\n",
"74 4852.75732421875\n",
"75 4611.75\n",
"76 4384.4951171875\n",
"77 4170.056640625\n",
"78 3967.59716796875\n",
"79 3776.37548828125\n",
"80 3595.642578125\n",
"81 3424.756103515625\n",
"82 3263.154052734375\n",
"83 3110.21923828125\n",
"84 2965.375244140625\n",
"85 2828.2041015625\n",
"86 2698.21337890625\n",
"87 2575.033935546875\n",
"88 2458.2099609375\n",
"89 2347.3671875\n",
"90 2242.168212890625\n",
"91 2142.27294921875\n",
"92 2047.3822021484375\n",
"93 1957.19287109375\n",
"94 1871.457275390625\n",
"95 1789.94287109375\n",
"96 1712.406494140625\n",
"97 1644.8389892578125\n",
"98 1581.1982421875\n",
"99 1520.4932861328125\n",
"100 1462.5537109375\n",
"101 1407.2244873046875\n",
"102 1354.361572265625\n",
"103 1303.8319091796875\n",
"104 1255.5167236328125\n",
"105 1209.3021240234375\n",
"106 1165.0728759765625\n",
"107 1122.72265625\n",
"108 1082.1737060546875\n",
"109 1043.3421630859375\n",
"110 1006.0917358398438\n",
"111 970.3775634765625\n",
"112 936.124267578125\n",
"113 903.264404296875\n",
"114 871.7230224609375\n",
"115 841.445068359375\n",
"116 812.3616333007812\n",
"117 784.4229125976562\n",
"118 757.5758056640625\n",
"119 731.7740478515625\n",
"120 706.9667358398438\n",
"121 683.1082763671875\n",
"122 660.1576538085938\n",
"123 638.0703125\n",
"124 616.8150634765625\n",
"125 596.3526000976562\n",
"126 576.65234375\n",
"127 557.674560546875\n",
"128 539.3936767578125\n",
"129 521.7766723632812\n",
"130 504.7982177734375\n",
"131 488.43304443359375\n",
"132 472.65509033203125\n",
"133 457.4384460449219\n",
"134 442.76123046875\n",
"135 428.6019287109375\n",
"136 414.9403991699219\n",
"137 401.7532958984375\n",
"138 389.02825927734375\n",
"139 376.74017333984375\n",
"140 364.876220703125\n",
"141 353.41754150390625\n",
"142 342.3507995605469\n",
"143 331.6587219238281\n",
"144 321.33111572265625\n",
"145 311.3492126464844\n",
"146 301.7040710449219\n",
"147 292.3786315917969\n",
"148 283.364013671875\n",
"149 274.649169921875\n",
"150 266.223388671875\n",
"151 258.0718688964844\n",
"152 250.1898651123047\n",
"153 242.56491088867188\n",
"154 235.18743896484375\n",
"155 228.050537109375\n",
"156 221.14320373535156\n",
"157 214.4588623046875\n",
"158 207.98941040039062\n",
"159 201.726806640625\n",
"160 195.66470336914062\n",
"161 189.7981719970703\n",
"162 184.11663818359375\n",
"163 178.61387634277344\n",
"164 173.2858123779297\n",
"165 168.12429809570312\n",
"166 163.12686157226562\n",
"167 158.28370666503906\n",
"168 153.59303283691406\n",
"169 149.0480194091797\n",
"170 144.6435089111328\n",
"171 140.3773193359375\n",
"172 136.24095153808594\n",
"173 132.23362731933594\n",
"174 128.3489532470703\n",
"175 124.58283996582031\n",
"176 120.9339370727539\n",
"177 117.39618682861328\n",
"178 113.9661865234375\n",
"179 110.6396713256836\n",
"180 107.41590881347656\n",
"181 104.28995513916016\n",
"182 101.25743103027344\n",
"183 98.3170166015625\n",
"184 95.4657974243164\n",
"185 92.7005844116211\n",
"186 90.01663208007812\n",
"187 87.4150161743164\n",
"188 84.89067077636719\n",
"189 82.44154357910156\n",
"190 80.0657730102539\n",
"191 77.76116943359375\n",
"192 75.52415466308594\n",
"193 73.35409545898438\n",
"194 71.24851989746094\n",
"195 69.2047119140625\n",
"196 67.22196960449219\n",
"197 65.29727935791016\n",
"198 63.42928695678711\n",
"199 61.61680603027344\n",
"200 59.857940673828125\n",
"201 58.149662017822266\n",
"202 56.492183685302734\n",
"203 54.88319778442383\n",
"204 53.3216552734375\n",
"205 51.80501174926758\n",
"206 50.332881927490234\n",
"207 48.90410614013672\n",
"208 47.516014099121094\n",
"209 46.168601989746094\n",
"210 44.860679626464844\n",
"211 43.59085464477539\n",
"212 42.357364654541016\n",
"213 41.16017150878906\n",
"214 39.99699020385742\n",
"215 38.867637634277344\n",
"216 37.77103042602539\n",
"217 36.70572280883789\n",
"218 35.67122268676758\n",
"219 34.666969299316406\n",
"220 33.69096374511719\n",
"221 32.743507385253906\n",
"222 31.822818756103516\n",
"223 30.92894744873047\n",
"224 30.06052017211914\n",
"225 29.216651916503906\n",
"226 28.397817611694336\n",
"227 27.601322174072266\n",
"228 26.82806968688965\n",
"229 26.07698631286621\n",
"230 25.34747886657715\n",
"231 24.6385498046875\n",
"232 23.949382781982422\n",
"233 23.280410766601562\n",
"234 22.630298614501953\n",
"235 21.99837875366211\n",
"236 21.384599685668945\n",
"237 20.788803100585938\n",
"238 20.209287643432617\n",
"239 19.646053314208984\n",
"240 19.098949432373047\n",
"241 18.567956924438477\n",
"242 18.051021575927734\n",
"243 17.54905891418457\n",
"244 17.061302185058594\n",
"245 16.587379455566406\n",
"246 16.12694549560547\n",
"247 15.67921257019043\n",
"248 15.243736267089844\n",
"249 14.821207046508789\n",
"250 14.410493850708008\n",
"251 14.010963439941406\n",
"252 13.623039245605469\n",
"253 13.24581527709961\n",
"254 12.879338264465332\n",
"255 12.523193359375\n",
"256 12.176714897155762\n",
"257 11.840224266052246\n",
"258 11.513036727905273\n",
"259 11.19500732421875\n",
"260 10.88607120513916\n",
"261 10.58561897277832\n",
"262 10.293493270874023\n",
"263 10.009596824645996\n",
"264 9.733649253845215\n",
"265 9.465527534484863\n",
"266 9.204741477966309\n",
"267 8.951228141784668\n",
"268 8.704797744750977\n",
"269 8.465054512023926\n",
"270 8.232160568237305\n",
"271 8.005952835083008\n",
"272 7.786048889160156\n",
"273 7.571745872497559\n",
"274 7.363881587982178\n",
"275 7.161560535430908\n",
"276 6.964933395385742\n",
"277 6.773778438568115\n",
"278 6.587972640991211\n",
"279 6.4071269035339355\n",
"280 6.231436729431152\n",
"281 6.060788631439209\n",
"282 5.89458703994751\n",
"283 5.73313570022583\n",
"284 5.576119899749756\n",
"285 5.423434257507324\n",
"286 5.274935722351074\n",
"287 5.130599021911621\n",
"288 4.990262508392334\n",
"289 4.853690147399902\n",
"290 4.721065998077393\n",
"291 4.592111110687256\n",
"292 4.46661376953125\n",
"293 4.344718933105469\n",
"294 4.2261457443237305\n",
"295 4.110623836517334\n",
"296 3.99847674369812\n",
"297 3.889352798461914\n",
"298 3.7832698822021484\n",
"299 3.6800806522369385\n",
"300 3.5797905921936035\n",
"301 3.4822211265563965\n",
"302 3.3873722553253174\n",
"303 3.2951035499572754\n",
"304 3.205399751663208\n",
"305 3.118119239807129\n",
"306 3.0331027507781982\n",
"307 2.950713634490967\n",
"308 2.870450735092163\n",
"309 2.7923619747161865\n",
"310 2.7164535522460938\n",
"311 2.642550230026245\n",
"312 2.5708224773406982\n",
"313 2.5008833408355713\n",
"314 2.433067798614502\n",
"315 2.366969108581543\n",
"316 2.3027470111846924\n",
"317 2.2402265071868896\n",
"318 2.1794466972351074\n",
"319 2.1203787326812744\n",
"320 2.062870979309082\n",
"321 2.007028102874756\n",
"322 1.952622652053833\n",
"323 1.8996343612670898\n",
"324 1.8482120037078857\n",
"325 1.7981364727020264\n",
"326 1.7493963241577148\n",
"327 1.7021578550338745\n",
"328 1.656005859375\n",
"329 1.6111729145050049\n",
"330 1.5675476789474487\n",
"331 1.5253342390060425\n",
"332 1.4840267896652222\n",
"333 1.4439213275909424\n",
"334 1.4048900604248047\n",
"335 1.3669472932815552\n",
"336 1.3300553560256958\n",
"337 1.2941076755523682\n",
"338 1.2591948509216309\n",
"339 1.2251873016357422\n",
"340 1.1921114921569824\n",
"341 1.1600029468536377\n",
"342 1.1286848783493042\n",
"343 1.0983271598815918\n",
"344 1.068663239479065\n",
"345 1.0398435592651367\n",
"346 1.0118087530136108\n",
"347 0.9846287965774536\n",
"348 0.9580874443054199\n",
"349 0.9322843551635742\n",
"350 0.9071281552314758\n",
"351 0.882738471031189\n",
"352 0.8589873909950256\n",
"353 0.8358998894691467\n",
"354 0.8134605288505554\n",
"355 0.7915393114089966\n",
"356 0.7702587246894836\n",
"357 0.7495856285095215\n",
"358 0.7294440269470215\n",
"359 0.7098613381385803\n",
"360 0.6907796859741211\n",
"361 0.6722580790519714\n",
"362 0.6541950106620789\n",
"363 0.6366435289382935\n",
"364 0.619549572467804\n",
"365 0.6029466986656189\n",
"366 0.5867413282394409\n",
"367 0.5710874199867249\n",
"368 0.5557659268379211\n",
"369 0.5408589839935303\n",
"370 0.5263494253158569\n",
"371 0.5122079253196716\n",
"372 0.49854475259780884\n",
"373 0.48518267273902893\n",
"374 0.472209632396698\n",
"375 0.45955076813697815\n",
"376 0.4472464919090271\n",
"377 0.4353131949901581\n",
"378 0.4236542582511902\n",
"379 0.41232940554618835\n",
"380 0.40129557251930237\n",
"381 0.39054378867149353\n",
"382 0.38012734055519104\n",
"383 0.3700036108493805\n",
"384 0.36012765765190125\n",
"385 0.35050609707832336\n",
"386 0.3411838710308075\n",
"387 0.3320454955101013\n",
"388 0.32319173216819763\n",
"389 0.3145967721939087\n",
"390 0.3061816692352295\n",
"391 0.29803869128227234\n",
"392 0.29008302092552185\n",
"393 0.28236180543899536\n",
"394 0.2748420238494873\n",
"395 0.26751869916915894\n",
"396 0.26039522886276245\n",
"397 0.2534610629081726\n",
"398 0.24670173227787018\n",
"399 0.2401532679796219\n",
"400 0.23375463485717773\n",
"401 0.2275371253490448\n",
"402 0.22150050103664398\n",
"403 0.2156049609184265\n",
"404 0.20986326038837433\n",
"405 0.20429474115371704\n",
"406 0.1988736242055893\n",
"407 0.19359922409057617\n",
"408 0.18844936788082123\n",
"409 0.1834636926651001\n",
"410 0.17861832678318024\n",
"411 0.17386555671691895\n",
"412 0.1692546308040619\n",
"413 0.16477948427200317\n",
"414 0.160414919257164\n",
"415 0.15614986419677734\n",
"416 0.15201903879642487\n",
"417 0.1480007916688919\n",
"418 0.1440792977809906\n",
"419 0.1402653604745865\n",
"420 0.13657070696353912\n",
"421 0.13295070827007294\n",
"422 0.129446342587471\n",
"423 0.12602251768112183\n",
"424 0.12268824130296707\n",
"425 0.11944673955440521\n",
"426 0.11631718277931213\n",
"427 0.11323467642068863\n",
"428 0.11024248600006104\n",
"429 0.10735208541154861\n",
"430 0.10451424866914749\n",
"431 0.10176849365234375\n",
"432 0.09908266365528107\n",
"433 0.09648296236991882\n",
"434 0.09394434094429016\n",
"435 0.09146330505609512\n",
"436 0.0890551209449768\n",
"437 0.0867248922586441\n",
"438 0.08444181084632874\n",
"439 0.08221925795078278\n",
"440 0.0800599679350853\n",
"441 0.07796868681907654\n",
"442 0.07591632008552551\n",
"443 0.07392437011003494\n",
"444 0.07198865711688995\n",
"445 0.07008222490549088\n",
"446 0.0682484433054924\n",
"447 0.06647884100675583\n",
"448 0.06473026424646378\n",
"449 0.06302813440561295\n",
"450 0.06137767806649208\n",
"451 0.0597919300198555\n",
"452 0.0582134947180748\n",
"453 0.056693755090236664\n",
"454 0.05521111190319061\n",
"455 0.053769730031490326\n",
"456 0.05237339809536934\n",
"457 0.05101470276713371\n",
"458 0.049686357378959656\n",
"459 0.04838278144598007\n",
"460 0.04712178558111191\n",
"461 0.04589943587779999\n",
"462 0.04470691829919815\n",
"463 0.04354717954993248\n",
"464 0.042408015578985214\n",
"465 0.041308339685201645\n",
"466 0.04023618996143341\n",
"467 0.039200473576784134\n",
"468 0.03817952424287796\n",
"469 0.03719625622034073\n",
"470 0.03622708469629288\n",
"471 0.035289373248815536\n",
"472 0.03438589721918106\n",
"473 0.0334918238222599\n",
"474 0.03262655809521675\n",
"475 0.03178119659423828\n",
"476 0.03095570020377636\n",
"477 0.03015051782131195\n",
"478 0.02938341721892357\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"479 0.028624456375837326\n",
"480 0.027895115315914154\n",
"481 0.02717461623251438\n",
"482 0.026480700820684433\n",
"483 0.025793474167585373\n",
"484 0.025134451687335968\n",
"485 0.024496881291270256\n",
"486 0.02386515401303768\n",
"487 0.023256389424204826\n",
"488 0.022666264325380325\n",
"489 0.02207905799150467\n",
"490 0.021514788269996643\n",
"491 0.020962411537766457\n",
"492 0.020433874800801277\n",
"493 0.019914157688617706\n",
"494 0.01941000483930111\n",
"495 0.018911005929112434\n",
"496 0.0184323787689209\n",
"497 0.017968682572245598\n",
"498 0.017506850883364677\n",
"499 0.017064126208424568\n"
]
}
],
"source": [
"import torch\n",
"\n",
"\"\"\"\n",
"A fully-connected ReLU network with one hidden layer and no biases, trained to\n",
"predict y from x by minimizing squared Euclidean distance.\n",
"\n",
"This implementation computes the forward pass using operations on PyTorch\n",
"Tensors, and uses PyTorch autograd to compute gradients.\n",
"\n",
"In this implementation we implement our own custom autograd function to perform\n",
"the ReLU function.\n",
"\"\"\"\n",
"\n",
"class MyReLU(torch.autograd.Function):\n",
"    \"\"\"\n",
"    We can implement our own custom autograd Functions by subclassing\n",
"    torch.autograd.Function and implementing the forward and backward passes\n",
"    which operate on Tensors.\n",
"    \"\"\"\n",
"    @staticmethod\n",
"    def forward(ctx, x):\n",
"        \"\"\"\n",
"        In the forward pass we receive a context object and a Tensor containing the\n",
"        input; we must return a Tensor containing the output, and we can use the\n",
"        context object to cache objects for use in the backward pass.\n",
"        \"\"\"\n",
"        ctx.save_for_backward(x)\n",
"        return x.clamp(min=0)\n",
"\n",
"    @staticmethod\n",
"    def backward(ctx, grad_output):\n",
"        \"\"\"\n",
"        In the backward pass we receive the context object and a Tensor containing\n",
"        the gradient of the loss with respect to the output produced during the\n",
"        forward pass. We can retrieve cached data from the context object, and must\n",
"        compute and return the gradient of the loss with respect to the input to the\n",
"        forward function.\n",
"        \"\"\"\n",
"        x, = ctx.saved_tensors\n",
"        grad_x = grad_output.clone()\n",
"        grad_x[x < 0] = 0\n",
"        return grad_x\n",
"\n",
"\n",
"device = torch.device('cpu')\n",
"# device = torch.device('cuda') # Uncomment this to run on GPU\n",
"\n",
"# N is batch size; D_in is input dimension;\n",
"# H is hidden dimension; D_out is output dimension.\n",
"N, D_in, H, D_out = 64, 1000, 100, 10\n",
"\n",
"# Create random Tensors to hold input and output\n",
"x = torch.randn(N, D_in, device=device)\n",
"y = torch.randn(N, D_out, device=device)\n",
"\n",
"# Create random Tensors for weights.\n",
"w1 = torch.randn(D_in, H, device=device, requires_grad=True)\n",
"w2 = torch.randn(H, D_out, device=device, requires_grad=True)\n",
"\n",
"learning_rate = 1e-6\n",
"for t in range(500):\n",
"    # Forward pass: compute predicted y using operations on Tensors; we call our\n",
"    # custom ReLU implementation using the MyReLU.apply function\n",
"    y_pred = MyReLU.apply(x.mm(w1)).mm(w2)\n",
"\n",
"    # Compute and print loss\n",
"    loss = (y_pred - y).pow(2).sum()\n",
"    print(t, loss.item())\n",
"\n",
"    # Use autograd to compute the backward pass.\n",
"    loss.backward()\n",
"\n",
"    with torch.no_grad():\n",
"        # Update weights using gradient descent\n",
"        w1 -= learning_rate * w1.grad\n",
"        w2 -= learning_rate * w2.grad\n",
"\n",
"        # Manually zero the gradients after running the backward pass\n",
"        w1.grad.zero_()\n",
"        w2.grad.zero_()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Read Later:**\n",
"\n",
"Documentation of ``autograd`` and ``Function`` is at\n",
"https://pytorch.org/docs/autograd\n",
"\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
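The `MyReLU` cell above pairs a forward pass (`x.clamp(min=0)`) with a backward pass that copies the incoming gradient and zeroes it wherever the input was negative. The sketch below reproduces that rule in plain NumPy and checks it against central finite differences. A check like this can help when debugging a custom `Function`. The helper names `relu_forward`, `relu_backward` and `finite_diff_grad` are hypothetical and are not part of PyTorch. The inputs are pushed away from zero because ReLU is not differentiable at its kink.

```python
import numpy as np

def relu_forward(x):
    # Same computation as x.clamp(min=0) in the PyTorch cell.
    return np.maximum(x, 0.0)

def relu_backward(x, grad_output):
    # Copy the upstream gradient and zero it where the input was negative,
    # mirroring grad_x[x < 0] = 0 in MyReLU.backward.
    grad_x = grad_output.copy()
    grad_x[x < 0] = 0.0
    return grad_x

def finite_diff_grad(f, x, eps=1e-6):
    # Central differences of a scalar function f with respect to each entry of x.
    grad = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        x_plus, x_minus = x.copy(), x.copy()
        x_plus[idx] += eps
        x_minus[idx] -= eps
        grad[idx] = (f(x_plus) - f(x_minus)) / (2 * eps)
    return grad

rng = np.random.RandomState(0)
x = rng.randn(4, 3)
x = x + 0.1 * np.sign(x)  # keep every entry at least 0.1 away from the kink at 0
w = rng.randn(4, 3)       # fixed weights that define a scalar loss

# Scalar loss L = sum(w * relu(x)), so the upstream gradient dL/d(relu(x)) is w.
loss = lambda z: np.sum(w * relu_forward(z))
analytic = relu_backward(x, w)
numeric = finite_diff_grad(loss, x)
print(np.max(np.abs(analytic - numeric)))
```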
"""
Neural ODE solver implementation in autograd.
From Appendix B of the paper https://arxiv.org/abs/1806.07366
"""
import scipy.integrate
import autograd.numpy as np
from autograd.extend import primitive, defvjp_argnums
from autograd import make_vjp
from autograd.misc import flatten
from autograd.builtins import tuple
odeint = primitive(scipy.integrate.odeint)
def grad_odeint_all(yt, func, y0, t, func_args, **kwargs):
    """
    Extended from "Scalable Inference of Ordinary Differential
    Equation Models of Biochemical Processes", Sec. 2.4.2
    Fabian Froehlich, Carolin Loos, Jan Hasenauer, 2017
    https://arxiv.org/pdf/1711.08079.pdf
    """
    T, D = np.shape(yt)
    flat_args, unflatten = flatten(func_args)

    def flat_func(y, t, flat_args):
        return func(y, t, *unflatten(flat_args))

    def unpack(x):
        # y, vjp_y, vjp_t, vjp_args
        return x[0:D], x[D:2 * D], x[2 * D], x[2 * D + 1:]

    def augmented_dynamics(augmented_state, t, flat_args):
        # Original system augmented with vjp_y, vjp_t and vjp_args.
        y, vjp_y, _, _ = unpack(augmented_state)
        vjp_all, dy_dt = make_vjp(flat_func, argnum=(0, 1, 2))(y, t, flat_args)
        vjp_y, vjp_t, vjp_args = vjp_all(-vjp_y)
        return np.hstack((dy_dt, vjp_y, vjp_t, vjp_args))

    def vjp_all(g):
        # Reuses the solver options (**kwargs) captured from the outer call.
        vjp_y = g[-1, :]
        vjp_t0 = 0
        time_vjp_list = []
        vjp_args = np.zeros(np.size(flat_args))
        for i in range(T - 1, 0, -1):
            # Compute effect of moving current time.
            vjp_cur_t = np.dot(func(yt[i, :], t[i], *func_args), g[i, :])
            time_vjp_list.append(vjp_cur_t)
            vjp_t0 = vjp_t0 - vjp_cur_t
            # Run augmented system backwards to the previous observation.
            aug_y0 = np.hstack((yt[i, :], vjp_y, vjp_t0, vjp_args))
            aug_ans = odeint(augmented_dynamics, aug_y0,
                             np.array([t[i], t[i - 1]]), tuple((flat_args,)), **kwargs)
            _, vjp_y, vjp_t0, vjp_args = unpack(aug_ans[1])
            # Add gradient from current output.
            vjp_y = vjp_y + g[i - 1, :]
        time_vjp_list.append(vjp_t0)
        vjp_times = np.hstack(time_vjp_list)[::-1]
        return None, vjp_y, vjp_times, unflatten(vjp_args)
    return vjp_all
def grad_argnums_wrapper(all_vjp_builder):
    """
    A generic autograd helper function. Takes a function that
    builds vjps for all arguments, and wraps it to return only required vjps.
    """
    def build_selected_vjps(argnums, ans, combined_args, kwargs):
        vjp_func = all_vjp_builder(ans, *combined_args, **kwargs)

        def chosen_vjps(g):
            # Return whichever vjps were asked for.
            all_vjps = vjp_func(g)
            return [all_vjps[argnum] for argnum in argnums]
        return chosen_vjps
    return build_selected_vjps


# Register the adjoint-based vjp for odeint when the module is imported.
# defvjp_argnums returns None, so there is nothing useful to print.
defvjp_argnums(odeint, grad_argnums_wrapper(grad_odeint_all))
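`grad_odeint_all` computes gradients by integrating an augmented state `[y, vjp_y, vjp_t, vjp_args]` backwards in time, with each vjp taken against `-vjp_y`. The sketch below applies the same idea to a scalar linear ODE, dy/dt = -theta * y, whose exact answers are known. For the loss L = y(T) we have y(T) = y0 exp(-theta T), so dL/dy0 = exp(-theta T) and dL/dtheta = -T y0 exp(-theta T). It is only an illustration under stated assumptions: it uses plain `scipy.integrate.odeint` with hand-written partial derivatives instead of autograd, it omits the time gradient, and the function name `adjoint_gradients` is hypothetical.

```python
import numpy as np
from scipy.integrate import odeint

def adjoint_gradients(y0, theta, T):
    """Return (reconstructed y0, dL/dy0, dL/dtheta) for L = y(T), dy/dt = -theta*y."""
    # Forward solve to obtain y(T).
    yT = odeint(lambda y, t: -theta * y, [y0], [0.0, T])[-1, 0]

    def augmented(state, t):
        y, a, _ = state   # a is the adjoint dL/dy(t)
        df_dy = -theta    # partial derivative of f with respect to y
        df_dtheta = -y    # partial derivative of f with respect to theta
        # Same sign convention as augmented_dynamics: vjps are taken with -a.
        return [-theta * y, -a * df_dy, -a * df_dtheta]

    # Start at T with a(T) = dL/dy(T) = 1 and a zero parameter gradient,
    # then integrate backwards from T to 0 (odeint accepts decreasing times).
    state_T = [yT, 1.0, 0.0]
    state_0 = odeint(augmented, state_T, [T, 0.0], rtol=1e-10, atol=1e-10)[-1]
    return state_0[0], state_0[1], state_0[2]

y0, theta, T = 2.0, 0.5, 1.5
y0_rec, dL_dy0, dL_dtheta = adjoint_gradients(y0, theta, T)
print(y0_rec, dL_dy0, dL_dtheta)
print(np.exp(-theta * T), -T * y0 * np.exp(-theta * T))  # exact values
```

The backward pass also reconstructs y(0) from y(T). `grad_odeint_all` does the same thing by restarting each backward segment from the stored forward values `yt[i, :]`.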
from __future__ import absolute_import
from __future__ import print_function
from builtins import range
import matplotlib.pyplot as plt
import numpy as npo
import autograd.numpy as np
from autograd import grad
from autograd.builtins import tuple
from autograd.misc.optimizers import adam
import autograd.numpy.random as npr
import scipy.integrate
from autograd.extend import primitive, defvjp_argnums
from autograd import make_vjp
from autograd.misc import flatten

# odeint is wrapped here as an autograd primitive whose vjp is registered
# below, instead of importing autograd.scipy.integrate.odeint.
odeint = primitive(scipy.integrate.odeint)
def grad_odeint(yt, func, y0, t, func_args, **kwargs):
    # Extended from "Scalable Inference of Ordinary Differential
    # Equation Models of Biochemical Processes", Sec. 2.4.2
    # Fabian Froehlich, Carolin Loos, Jan Hasenauer, 2017
    # https://arxiv.org/abs/1711.08079
    T, D = np.shape(yt)
    flat_args, unflatten = flatten(func_args)

    def flat_func(y, t, flat_args):
        return func(y, t, *unflatten(flat_args))

    def unpack(x):
        # y, vjp_y, vjp_t, vjp_args
        return x[0:D], x[D:2 * D], x[2 * D], x[2 * D + 1:]

    def augmented_dynamics(augmented_state, t, flat_args):
        # Original system augmented with vjp_y, vjp_t and vjp_args.
        y, vjp_y, _, _ = unpack(augmented_state)
        vjp_all, dy_dt = make_vjp(flat_func, argnum=(0, 1, 2))(y, t, flat_args)
        vjp_y, vjp_t, vjp_args = vjp_all(-vjp_y)
        return np.hstack((dy_dt, vjp_y, vjp_t, vjp_args))

    def vjp_all(g):
        vjp_y = g[-1, :]
        vjp_t0 = 0
        time_vjp_list = []
        vjp_args = np.zeros(np.size(flat_args))
        for i in range(T - 1, 0, -1):
            # Compute effect of moving measurement time.
            vjp_cur_t = np.dot(func(yt[i, :], t[i], *func_args), g[i, :])
            time_vjp_list.append(vjp_cur_t)
            vjp_t0 = vjp_t0 - vjp_cur_t
            # Run augmented system backwards to the previous observation.
            aug_y0 = np.hstack((yt[i, :], vjp_y, vjp_t0, vjp_args))
            aug_ans = odeint(augmented_dynamics, aug_y0,
                             np.array([t[i], t[i - 1]]), tuple((flat_args,)), **kwargs)
            _, vjp_y, vjp_t0, vjp_args = unpack(aug_ans[1])
            # Add gradient from current output.
            vjp_y = vjp_y + g[i - 1, :]
        time_vjp_list.append(vjp_t0)
        vjp_times = np.hstack(time_vjp_list)[::-1]
        return None, vjp_y, vjp_times, unflatten(vjp_args)
    return vjp_all
def argnums_unpack(all_vjp_builder):
    # A generic autograd helper function. Takes a function that
    # builds vjps for all arguments, and wraps it to return only required vjps.
    def build_selected_vjps(argnums, ans, combined_args, kwargs):
        vjp_func = all_vjp_builder(ans, *combined_args, **kwargs)

        def chosen_vjps(g):  # Returns whichever vjps were asked for.
            all_vjps = vjp_func(g)
            return [all_vjps[argnum] for argnum in argnums]
        return chosen_vjps
    return build_selected_vjps


defvjp_argnums(odeint, argnums_unpack(grad_odeint))
N = 30  # Dataset size
D = 2   # Data dimension
max_T = 1.5


# True dynamics: a two-dimensional damped cubic oscillator, dy/dt = (y**3) A.
def func(y, t0, A):
    return np.dot(y**3, A)


def nn_predict(inputs, t, params):
    for W, b in params:
        outputs = np.dot(inputs, W) + b
        inputs = np.maximum(0, outputs)
    return outputs


def init_nn_params(scale, layer_sizes, rs=npr.RandomState(0)):
    """Build a list of (weights, biases) tuples, one for each layer."""
    return [(rs.randn(insize, outsize) * scale,   # weight matrix
             rs.randn(outsize) * scale)           # bias vector
            for insize, outsize in zip(layer_sizes[:-1], layer_sizes[1:])]


# Define neural ODE model.
def ode_pred(params, y0, t):
    return odeint(nn_predict, y0, t, tuple((params,)), rtol=0.01)


def L1_loss(pred, targets):
    return np.mean(np.abs(pred - targets))
if __name__ == '__main__':

    # Generate data from true dynamics.
    true_y0 = np.array([2., 0.]).T
    t = np.linspace(0., max_T, N)
    true_A = np.array([[-0.1, 2.0], [-2.0, -0.1]])
    true_y = odeint(func, true_y0, t, args=(true_A,))

    def train_loss(params, iter):
        pred = ode_pred(params, true_y0, t)
        return L1_loss(pred, true_y)

    # Set up figure
    fig = plt.figure(figsize=(12, 4), facecolor='white')
    ax_traj = fig.add_subplot(131, frameon=False)
    ax_phase = fig.add_subplot(132, frameon=False)
    ax_vecfield = fig.add_subplot(133, frameon=False)
    plt.show(block=False)

    # Plots data and learned dynamics.
    def callback(params, iter, g):
        pred = ode_pred(params, true_y0, t)
        print("Iteration {:d} train loss {:.6f}".format(
              iter, L1_loss(pred, true_y)))

        ax_traj.cla()
        ax_traj.set_title('Trajectories')
        ax_traj.set_xlabel('t')
        ax_traj.set_ylabel('x,y')
        ax_traj.plot(t, true_y[:, 0], '-', t, true_y[:, 1], 'g-')
        ax_traj.plot(t, pred[:, 0], '--', t, pred[:, 1], 'b--')
        ax_traj.set_xlim(t.min(), t.max())
        ax_traj.set_ylim(-2, 2)
        ax_traj.xaxis.set_ticklabels([])
        ax_traj.yaxis.set_ticklabels([])
        ax_traj.legend()

        ax_phase.cla()
        ax_phase.set_title('Phase Portrait')
        ax_phase.set_xlabel('x')
        ax_phase.set_ylabel('y')
        ax_phase.plot(true_y[:, 0], true_y[:, 1], 'g-')
        ax_phase.plot(pred[:, 0], pred[:, 1], 'b--')
        ax_phase.set_xlim(-2, 2)
        ax_phase.set_ylim(-2, 2)
        ax_phase.xaxis.set_ticklabels([])
        ax_phase.yaxis.set_ticklabels([])

        ax_vecfield.cla()
        ax_vecfield.set_title('Learned Vector Field')
        ax_vecfield.set_xlabel('x')
        ax_vecfield.set_ylabel('y')
        ax_vecfield.xaxis.set_ticklabels([])
        ax_vecfield.yaxis.set_ticklabels([])

        # vector field plot
        y, x = npo.mgrid[-2:2:21j, -2:2:21j]
        dydt = nn_predict(np.stack([x, y], -1).reshape(21 * 21, 2), 0,
                          params).reshape(-1, 2)
        mag = np.sqrt(dydt[:, 0]**2 + dydt[:, 1]**2).reshape(-1, 1)
        dydt = (dydt / mag)
        dydt = dydt.reshape(21, 21, 2)

        ax_vecfield.streamplot(x, y, dydt[:, :, 0], dydt[:, :, 1], color="black")
        ax_vecfield.set_xlim(-2, 2)
        ax_vecfield.set_ylim(-2, 2)

        fig.tight_layout()
        plt.draw()
        plt.pause(0.001)

    # Train neural net dynamics to match data.
    init_params = init_nn_params(0.1, layer_sizes=[D, 150, D])
    optimized_params = adam(grad(train_loss), init_params,
                            num_iters=1000, callback=callback)
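`ode_pred` uses the small ReLU network from `nn_predict` as the vector field of an ODE and integrates it with `odeint`. `train_loss` then compares the resulting trajectory with data generated from the cubic dynamics in `func`. The forward half of this pipeline needs no autograd, so it can be sketched in plain NumPy and SciPy. The sketch assumes the same sizes as the script: N = 30, D = 2, one 150-unit hidden layer, and an initial weight scale of 0.1. It only evaluates the untrained model and its L1 loss; the gradients still come from the adjoint code. The names `true_dynamics` and `mlp_dynamics` are hypothetical stand-ins for `func` and `nn_predict`.

```python
import numpy as np
from scipy.integrate import odeint

N, D, max_T = 30, 2, 1.5

def true_dynamics(y, t, A):
    # Same cubic dynamics as func(): dy/dt = (y**3) A.
    return np.dot(y**3, A)

def mlp_dynamics(y, t, params):
    # Same structure as nn_predict(): ReLU after hidden layers, linear output.
    h = y
    for W, b in params[:-1]:
        h = np.maximum(0.0, np.dot(h, W) + b)
    W, b = params[-1]
    return np.dot(h, W) + b

# Random (weight, bias) pairs for a D -> 150 -> D network.
rs = np.random.RandomState(0)
params = [(0.1 * rs.randn(D, 150), 0.1 * rs.randn(150)),
          (0.1 * rs.randn(150, D), 0.1 * rs.randn(D))]

t = np.linspace(0.0, max_T, N)
y0 = np.array([2.0, 0.0])
A = np.array([[-0.1, 2.0], [-2.0, -0.1]])

true_y = odeint(true_dynamics, y0, t, args=(A,))
pred = odeint(mlp_dynamics, y0, t, args=(params,), rtol=0.01)
l1 = np.mean(np.abs(pred - true_y))  # same quantity as L1_loss(pred, true_y)
print(pred.shape, l1)
```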
import tensorflow as tf
import autograd.numpy as np
from autograd import grad
from tensorflow.python.framework import function
rng = np.random.RandomState(42)
x_np = rng.randn(4, 4).astype(np.float32)
with tf.device('/cpu:0'):
    x = tf.Variable(x_np)


def tf_loss(a):
    return tf.reduce_sum(tf.square(a))


def np_loss(a):
    return np.array(2.).astype(np.float32) * np.square(a).sum()


grad_np_loss = grad(np_loss)

l = tf_loss(x)
g = tf.gradients(l, x)

with tf.device('/cpu:0'):
    np_in_tf = tf.py_func(np_loss, [x], tf.float32)
    npgrad_in_tf = tf.py_func(grad_np_loss, [x], tf.float32)


@function.Defun()
def op_grad(x, grad):
    return [tf.py_func(grad_np_loss, [x], tf.float32)]


@function.Defun(grad_func=op_grad)
def tf_replaced_grad_loss(a):
    return tf_loss(a)


with tf.device('/cpu:0'):
    tf_np_grad = tf.gradients(tf_replaced_grad_loss(x), x)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print("Tensorflow gradient:\n")
    print(sess.run(g)[0])
    print("\nNumpy gradient (should be 2 times tf version):\n")
    print(grad_np_loss(x_np))
    print("\nNumpy gradient evaluated in Tensorflow:\n")
    print(sess.run(npgrad_in_tf))
    print("\nNumpy gradient put in Tensorflow graph:\n")
    print(sess.run(tf_np_grad)[0])
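The script compares TensorFlow's gradient of `tf_loss(a) = sum(a**2)`, which is 2a, with autograd's gradient of `np_loss(a) = 2 * sum(a**2)`, which is 4a. That is why the printout says the NumPy gradient should be twice the TensorFlow one. The relationship can be confirmed without TensorFlow or autograd by using central finite differences in NumPy. The loss names below are stand-ins for `tf_loss` and `np_loss`, and the helper `numeric_grad` is hypothetical.

```python
import numpy as np

def tf_style_loss(a):
    # Mirrors tf_loss: sum of squares, whose gradient is 2 * a.
    return np.sum(np.square(a))

def np_style_loss(a):
    # Mirrors np_loss: twice the sum of squares, whose gradient is 4 * a.
    return 2.0 * np.sum(np.square(a))

def numeric_grad(f, a, eps=1e-6):
    # Central finite differences, one entry at a time.
    g = np.zeros_like(a)
    for idx in np.ndindex(a.shape):
        a_plus, a_minus = a.copy(), a.copy()
        a_plus[idx] += eps
        a_minus[idx] -= eps
        g[idx] = (f(a_plus) - f(a_minus)) / (2 * eps)
    return g

rng = np.random.RandomState(42)
a = rng.randn(4, 4)  # kept in float64 so the finite differences stay accurate
g_tf = numeric_grad(tf_style_loss, a)
g_np = numeric_grad(np_style_loss, a)
print(np.allclose(g_tf, 2 * a, atol=1e-5), np.allclose(g_np, 2 * g_tf, atol=1e-5))
```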