-
-
Save tttamaki/07709e726b60b4cb9e85b6351a9e7d31 to your computer and use it in GitHub Desktop.
Pytorchの「.detach()」と「with no_grad():」と「.requires_grad = False」の違い
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"metadata": { | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.6" | |
}, | |
"orig_nbformat": 4, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3.9.6 64-bit" | |
}, | |
"interpreter": { | |
"hash": "4cd7ab41f5fca4b9b44701077e38c5ffd31fe66a6cab21e0214b68d958d0e462" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2, | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"[tensor(2., requires_grad=True), tensor(3., requires_grad=True), tensor(5., requires_grad=True), tensor(7., requires_grad=True), tensor(11., requires_grad=True), tensor(13., requires_grad=True), tensor(17., requires_grad=True)]\n" | |
] | |
} | |
], | |
"source": [ | |
"import numpy as np\n", | |
"import torch\n", | |
"\n", | |
"a = torch.tensor(2., requires_grad=True)\n", | |
"b = torch.tensor(3., requires_grad=True)\n", | |
"c = torch.tensor(5., requires_grad=True)\n", | |
"d = torch.tensor(7., requires_grad=True)\n", | |
"e = torch.tensor(11., requires_grad=True)\n", | |
"f = torch.tensor(13., requires_grad=True)\n", | |
"\n", | |
"x = torch.tensor(17., requires_grad=True)\n", | |
"\n", | |
"params = [a, b, c, d, e, f, x]\n", | |
"\n", | |
"print(params)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"y=37.0, z=192.0, w=2125.0\nf.grad=1: tensor(1.)\ne.grad=z=192.0: tensor(192.)\nd.grad=e=11.0: tensor(11.)\nc.grad=ey=407.0: tensor(407.)\nb.grad=ec=55.0: tensor(55.)\na.grad=xec=935.0: tensor(935.)\nx.grad=aec=110.0: tensor(110.)\n" | |
] | |
} | |
], | |
"source": [ | |
"for p in params:\n", | |
" if p.grad:\n", | |
" p.grad.data = torch.tensor(0.0) # zero_grad\n", | |
"\n", | |
"y = a * x + b\n", | |
"z = c * y + d\n", | |
"w = e * z + f\n", | |
"w.backward()\n", | |
"\n", | |
"print('y={}, z={}, w={}'.format(y, z, w))\n", | |
"print('f.grad=1:', f.grad)\n", | |
"print('e.grad=z={}:'.format(z.item()), e.grad)\n", | |
"print('d.grad=e={}:'.format(e.item()), d.grad)\n", | |
"print('c.grad=ey={}:'.format(e.item() * y.item()), c.grad)\n", | |
"print('b.grad=ec={}:'.format(e.item() * c.item()), b.grad)\n", | |
"print('a.grad=xec={}:'.format(x.item() * e.item() * c.item()), a.grad)\n", | |
"print('x.grad=aec={}:'.format(a.item() * e.item() * c.item()), x.grad)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"y=37.0, z=192.0, w=2125.0\nf.grad=1: tensor(1.)\ne.grad=z=192.0: tensor(192.)\nd.grad=e=11.0: tensor(11.)\nc.grad=ey=407.0: tensor(407.)\nb.grad=ec=55.0: tensor(0.)\na.grad=xec=935.0: tensor(0.)\nx.grad=aec=110.0: tensor(0.)\n" | |
] | |
} | |
], | |
"source": [ | |
"for p in params:\n", | |
" if p.grad:\n", | |
" p.grad.data = torch.tensor(0.0) # zero_grad\n", | |
"\n", | |
"y = a * x + b\n", | |
"z = c * y.detach() + d\n", | |
"w = e * z + f\n", | |
"w.backward()\n", | |
"\n", | |
"print('y={}, z={}, w={}'.format(y, z, w))\n", | |
"print('f.grad=1:', f.grad)\n", | |
"print('e.grad=z={}:'.format(z.item()), e.grad)\n", | |
"print('d.grad=e={}:'.format(e.item()), d.grad)\n", | |
"print('c.grad=ey={}:'.format(e.item() * y.item()), c.grad)\n", | |
"print('b.grad=ec={}:'.format(e.item() * c.item()), b.grad)\n", | |
"print('a.grad=xec={}:'.format(x.item() * e.item() * c.item()), a.grad)\n", | |
"print('x.grad=aec={}:'.format(a.item() * e.item() * c.item()), x.grad)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"y=37.0, z=192.0, w=2125.0\nf.grad=1: tensor(1.)\ne.grad=z=192.0: tensor(192.)\nd.grad=e=11.0: tensor(0.)\nc.grad=ey=407.0: tensor(0.)\nb.grad=ec=55.0: tensor(0.)\na.grad=xec=935.0: tensor(0.)\nx.grad=aec=110.0: tensor(0.)\n" | |
] | |
} | |
], | |
"source": [ | |
"for p in params:\n", | |
" if p.grad:\n", | |
" p.grad.data = torch.tensor(0.0) # zero_grad\n", | |
"\n", | |
"y = a * x + b\n", | |
"z = c * y + d\n", | |
"w = e * z.detach() + f\n", | |
"w.backward()\n", | |
"\n", | |
"print('y={}, z={}, w={}'.format(y, z, w))\n", | |
"print('f.grad=1:', f.grad)\n", | |
"print('e.grad=z={}:'.format(z.item()), e.grad)\n", | |
"print('d.grad=e={}:'.format(e.item()), d.grad)\n", | |
"print('c.grad=ey={}:'.format(e.item() * y.item()), c.grad)\n", | |
"print('b.grad=ec={}:'.format(e.item() * c.item()), b.grad)\n", | |
"print('a.grad=xec={}:'.format(x.item() * e.item() * c.item()), a.grad)\n", | |
"print('x.grad=aec={}:'.format(a.item() * e.item() * c.item()), x.grad)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"y=37.0, z=192.0, w=2125.0\nf.grad=1: tensor(1.)\ne.grad=z=192.0: tensor(192.)\nd.grad=e=11.0: tensor(0.)\nc.grad=ey=407.0: tensor(0.)\nb.grad=ec=55.0: tensor(0.)\na.grad=xec=935.0: tensor(0.)\nx.grad=aec=110.0: tensor(0.)\n" | |
] | |
} | |
], | |
"source": [ | |
"for p in params:\n", | |
" if p.grad:\n", | |
" p.grad.data = torch.tensor(0.0) # zero_grad\n", | |
"\n", | |
"y = a * x + b\n", | |
"with torch.no_grad():\n", | |
" z = c * y + d\n", | |
"w = e * z + f\n", | |
"w.backward()\n", | |
"\n", | |
"print('y={}, z={}, w={}'.format(y, z, w))\n", | |
"print('f.grad=1:', f.grad)\n", | |
"print('e.grad=z={}:'.format(z.item()), e.grad)\n", | |
"print('d.grad=e={}:'.format(e.item()), d.grad)\n", | |
"print('c.grad=ey={}:'.format(e.item() * y.item()), c.grad)\n", | |
"print('b.grad=ec={}:'.format(e.item() * c.item()), b.grad)\n", | |
"print('a.grad=xec={}:'.format(x.item() * e.item() * c.item()), a.grad)\n", | |
"print('x.grad=aec={}:'.format(a.item() * e.item() * c.item()), x.grad)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"y=37.0, z=192.0, w=2125.0\n", | |
"f.grad=1: tensor(1.)\n", | |
"e.grad=z=192.0: tensor(192.)\n", | |
"d.grad=e=11.0: tensor(11.)\n", | |
"c.grad=ey=407.0: tensor(407.)\n", | |
"b.grad=ec=55.0: tensor(0.)\n", | |
"a.grad=xec=935.0: tensor(0.)\n", | |
"x.grad=aec=110.0: tensor(0.)\n" | |
] | |
} | |
], | |
"source": [ | |
"for p in params:\n", | |
" if p.grad:\n", | |
" p.grad.data = torch.tensor(0.0) # zero_grad\n", | |
"\n", | |
"with torch.no_grad():\n", | |
" y = a * x + b\n", | |
"z = c * y + d\n", | |
"w = e * z + f\n", | |
"w.backward()\n", | |
"\n", | |
"print('y={}, z={}, w={}'.format(y, z, w))\n", | |
"print('f.grad=1:', f.grad)\n", | |
"print('e.grad=z={}:'.format(z.item()), e.grad)\n", | |
"print('d.grad=e={}:'.format(e.item()), d.grad)\n", | |
"print('c.grad=ey={}:'.format(e.item() * y.item()), c.grad)\n", | |
"print('b.grad=ec={}:'.format(e.item() * c.item()), b.grad)\n", | |
"print('a.grad=xec={}:'.format(x.item() * e.item() * c.item()), a.grad)\n", | |
"print('x.grad=aec={}:'.format(a.item() * e.item() * c.item()), x.grad)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"y=37.0, z=192.0, w=2125.0\nf.grad=1: tensor(1.)\ne.grad=z=192.0: tensor(192.)\nd.grad=e=11.0: tensor(0.)\nc.grad=ey=407.0: tensor(0.)\nb.grad=ec=55.0: tensor(55.)\na.grad=xec=935.0: tensor(935.)\nx.grad=aec=110.0: tensor(110.)\n" | |
] | |
} | |
], | |
"source": [ | |
"for p in params:\n", | |
" if p.grad:\n", | |
" p.grad.data = torch.tensor(0.0) # zero_grad\n", | |
"\n", | |
"c.requires_grad = False\n", | |
"d.requires_grad = False\n", | |
"\n", | |
"y = a * x + b\n", | |
"z = c * y + d\n", | |
"w = e * z + f\n", | |
"w.backward()\n", | |
"\n", | |
"print('y={}, z={}, w={}'.format(y, z, w))\n", | |
"print('f.grad=1:', f.grad)\n", | |
"print('e.grad=z={}:'.format(z.item()), e.grad)\n", | |
"print('d.grad=e={}:'.format(e.item()), d.grad)\n", | |
"print('c.grad=ey={}:'.format(e.item() * y.item()), c.grad)\n", | |
"print('b.grad=ec={}:'.format(e.item() * c.item()), b.grad)\n", | |
"print('a.grad=xec={}:'.format(x.item() * e.item() * c.item()), a.grad)\n", | |
"print('x.grad=aec={}:'.format(a.item() * e.item() * c.item()), x.grad)\n", | |
"\n", | |
"c.requires_grad = True\n", | |
"d.requires_grad = True" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"y=37.0, z=192.0, w=2125.0\nf.grad=1: tensor(0.)\ne.grad=z=192.0: tensor(0.)\nd.grad=e=11.0: tensor(11.)\nc.grad=ey=407.0: tensor(407.)\nb.grad=ec=55.0: tensor(55.)\na.grad=xec=935.0: tensor(935.)\nx.grad=aec=110.0: tensor(110.)\n" | |
] | |
} | |
], | |
"source": [ | |
"for p in params:\n", | |
" if p.grad:\n", | |
" p.grad.data = torch.tensor(0.0) # zero_grad\n", | |
"\n", | |
"e.requires_grad = False\n", | |
"f.requires_grad = False\n", | |
"\n", | |
"y = a * x + b\n", | |
"z = c * y + d\n", | |
"w = e * z + f\n", | |
"w.backward()\n", | |
"\n", | |
"print('y={}, z={}, w={}'.format(y, z, w))\n", | |
"print('f.grad=1:', f.grad)\n", | |
"print('e.grad=z={}:'.format(z.item()), e.grad)\n", | |
"print('d.grad=e={}:'.format(e.item()), d.grad)\n", | |
"print('c.grad=ey={}:'.format(e.item() * y.item()), c.grad)\n", | |
"print('b.grad=ec={}:'.format(e.item() * c.item()), b.grad)\n", | |
"print('a.grad=xec={}:'.format(x.item() * e.item() * c.item()), a.grad)\n", | |
"print('x.grad=aec={}:'.format(a.item() * e.item() * c.item()), x.grad)\n", | |
"\n", | |
"e.requires_grad = True\n", | |
"f.requires_grad = True" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "error", | |
"ename": "RuntimeError", | |
"evalue": "you can only change requires_grad flags of leaf variables. If you want to use a computed variable in a subgraph that doesn't require differentiation use var_no_grad = var.detach().", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", | |
"\u001b[0;32m/tmp/ipykernel_8243/2375783274.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0ma\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mb\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0my\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrequires_grad\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 7\u001b[0m \u001b[0mz\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mc\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0my\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0md\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0mw\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0me\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mz\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;31mRuntimeError\u001b[0m: you can only change requires_grad flags of leaf variables. If you want to use a computed variable in a subgraph that doesn't require differentiation use var_no_grad = var.detach()." | |
] | |
} | |
], | |
"source": [ | |
"for p in params:\n", | |
" if p.grad:\n", | |
" p.grad.data = torch.tensor(0.0) # zero_grad\n", | |
"\n", | |
"y = a * x + b\n", | |
"y.requires_grad = False\n", | |
"z = c * y + d\n", | |
"w = e * z + f\n", | |
"w.backward()\n", | |
"\n", | |
"print('y={}, z={}, w={}'.format(y, z, w))\n", | |
"print('f.grad=1:', f.grad)\n", | |
"print('e.grad=z={}:'.format(z.item()), e.grad)\n", | |
"print('d.grad=e={}:'.format(e.item()), d.grad)\n", | |
"print('c.grad=ey={}:'.format(e.item() * y.item()), c.grad)\n", | |
"print('b.grad=ec={}:'.format(e.item() * c.item()), b.grad)\n", | |
"print('a.grad=xec={}:'.format(x.item() * e.item() * c.item()), a.grad)\n", | |
"print('x.grad=aec={}:'.format(a.item() * e.item() * c.item()), x.grad)\n", | |
"\n" | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment