Skip to content

Instantly share code, notes, and snippets.

@anand086
Last active July 22, 2021 06:26
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save anand086/0aac6c5644b94599d7df646602a71e11 to your computer and use it in GitHub Desktop.
Save anand086/0aac6c5644b94599d7df646602a71e11 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"!pip install -Uqq fastbook\n",
"import fastbook\n",
"fastbook.setup_book()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"from fastai.vision.all import *\n",
"from fastbook import *\n",
"\n",
"matplotlib.rc('image', cmap='Greys')"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 2.48.0 (20210717.2242)\n",
" -->\n",
"<!-- Title: G Pages: 1 -->\n",
"<svg width=\"661pt\" height=\"78pt\"\n",
" viewBox=\"0.00 0.00 660.87 78.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 74)\">\n",
"<title>G</title>\n",
"<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-74 656.87,-74 656.87,4 -4,4\"/>\n",
"<!-- init -->\n",
"<g id=\"node1\" class=\"node\">\n",
"<title>init</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"27\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"27\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">init</text>\n",
"</g>\n",
"<!-- predict -->\n",
"<g id=\"node2\" class=\"node\">\n",
"<title>predict</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"135.2\" cy=\"-18\" rx=\"44.39\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"135.2\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">predict</text>\n",
"</g>\n",
"<!-- init&#45;&gt;predict -->\n",
"<g id=\"edge1\" class=\"edge\">\n",
"<title>init&#45;&gt;predict</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M54.25,-18C62.37,-18 71.63,-18 80.89,-18\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"80.89,-21.5 90.89,-18 80.89,-14.5 80.89,-21.5\"/>\n",
"</g>\n",
"<!-- loss -->\n",
"<g id=\"node3\" class=\"node\">\n",
"<title>loss</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"244.99\" cy=\"-52\" rx=\"28.7\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"244.99\" y=\"-48.3\" font-family=\"Times,serif\" font-size=\"14.00\">loss</text>\n",
"</g>\n",
"<!-- predict&#45;&gt;loss -->\n",
"<g id=\"edge2\" class=\"edge\">\n",
"<title>predict&#45;&gt;loss</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M170.6,-28.85C183.05,-32.78 197.09,-37.21 209.54,-41.14\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"208.53,-44.49 219.12,-44.16 210.64,-37.81 208.53,-44.49\"/>\n",
"</g>\n",
"<!-- gradient -->\n",
"<g id=\"node4\" class=\"node\">\n",
"<title>gradient</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"406.63\" cy=\"-52\" rx=\"50.09\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"406.63\" y=\"-48.3\" font-family=\"Times,serif\" font-size=\"14.00\">gradient</text>\n",
"</g>\n",
"<!-- loss&#45;&gt;gradient -->\n",
"<g id=\"edge3\" class=\"edge\">\n",
"<title>loss&#45;&gt;gradient</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M273.8,-52C293.82,-52 321.57,-52 346.45,-52\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"346.55,-55.5 356.55,-52 346.55,-48.5 346.55,-55.5\"/>\n",
"</g>\n",
"<!-- step -->\n",
"<g id=\"node5\" class=\"node\">\n",
"<title>step</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"524.23\" cy=\"-18\" rx=\"30.59\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"524.23\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">step</text>\n",
"</g>\n",
"<!-- gradient&#45;&gt;step -->\n",
"<g id=\"edge4\" class=\"edge\">\n",
"<title>gradient&#45;&gt;step</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M445.8,-40.77C459.01,-36.89 473.76,-32.55 486.82,-28.71\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"487.82,-32.06 496.43,-25.88 485.85,-25.35 487.82,-32.06\"/>\n",
"</g>\n",
"<!-- step&#45;&gt;predict -->\n",
"<g id=\"edge6\" class=\"edge\">\n",
"<title>step&#45;&gt;predict</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M493.68,-18C428.65,-18 272.39,-18 189.67,-18\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"189.47,-14.5 179.47,-18 189.47,-21.5 189.47,-14.5\"/>\n",
"<text text-anchor=\"middle\" x=\"315.09\" y=\"-21.8\" font-family=\"Times,serif\" font-size=\"14.00\">repeat</text>\n",
"</g>\n",
"<!-- stop -->\n",
"<g id=\"node6\" class=\"node\">\n",
"<title>stop</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"622.32\" cy=\"-18\" rx=\"30.59\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"622.32\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">stop</text>\n",
"</g>\n",
"<!-- step&#45;&gt;stop -->\n",
"<g id=\"edge5\" class=\"edge\">\n",
"<title>step&#45;&gt;stop</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M554.84,-18C563.24,-18 572.53,-18 581.44,-18\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"581.64,-21.5 591.64,-18 581.64,-14.5 581.64,-21.5\"/>\n",
"</g>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.files.Source at 0x7f8d64ce64e0>"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"gv('''\n",
"init->predict->loss->gradient->step->stop\n",
"step->predict[label=repeat]\n",
"''')"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def f(x): return x**2"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ec2-user/SageMaker/.env/fastai/lib/python3.6/site-packages/fastbook/__init__.py:73: UserWarning: Not providing a value for linspace's steps is deprecated and will throw a runtime error in a future release. This warning will appear only once per process. (Triggered internally at /pytorch/aten/src/ATen/native/RangeFactories.cpp:25.)\n",
" x = torch.linspace(min,max)\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plot_function(f, 'x', 'x**2')"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plot_function(f, 'x', 'x**2')\n",
"plt.scatter(-1.5, f(-1.5), color='red');"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Calculating Gradients"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# let's pick a tensor value at which we want the gradient\n",
"xt = tensor(4.).requires_grad_()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(4., requires_grad=True)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"xt"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(16., grad_fn=<PowBackward0>)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# calculate function with that value\n",
"yt = f(xt)\n",
"yt"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# tell PyTorch to calculate the gradient\n",
"yt.backward()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(8.)"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# view the gradient by checking the \"grad\" attribute of the tensor\n",
"# the derivative of x**2 is 2*x, so the gradient at x=4 is 8\n",
"xt.grad"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# same steps, now with a vector argument for the function\n",
"xt = tensor([3.,4.,10.]).requires_grad_()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([ 3., 4., 10.], requires_grad=True)"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"xt"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"# redefine f to sum the squares so it returns a scalar: backward() needs a scalar output\n",
"def f(x): \n",
" return (x**2).sum()"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(125., grad_fn=<SumBackward0>)"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# evaluate f at xt; the sum makes the result a single scalar,\n",
"# which in this case will be 125 (3**2 + 4**2 + 10**2 = 9 + 16 + 100 = 125)\n",
"yt = f(xt)\n",
"yt"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"# tell PyTorch to calculate the gradient\n",
"yt.backward()"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([ 6., 8., 20.])"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# view the gradient by checking the \"grad\" attribute of the tensor\n",
"# the derivative of sum(x**2) with respect to each element is 2*x, giving [6, 8, 20]\n",
"xt.grad"
]
}
],
"metadata": {
"jupytext": {
"split_at_heading": true
},
"kernelspec": {
"display_name": "fastai",
"language": "python",
"name": "fastai"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.13"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment