@tttamaki
Created July 14, 2021 09:56
The difference between PyTorch's ".detach()", "with no_grad():", and ".requires_grad = False"
{
"metadata": {
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.6"
},
"orig_nbformat": 4,
"kernelspec": {
"name": "python3",
"display_name": "Python 3.9.6 64-bit"
},
"interpreter": {
"hash": "4cd7ab41f5fca4b9b44701077e38c5ffd31fe66a6cab21e0214b68d958d0e462"
}
},
"nbformat": 4,
"nbformat_minor": 2,
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"[tensor(2., requires_grad=True), tensor(3., requires_grad=True), tensor(5., requires_grad=True), tensor(7., requires_grad=True), tensor(11., requires_grad=True), tensor(13., requires_grad=True), tensor(17., requires_grad=True)]\n"
]
}
],
"source": [
"import numpy as numpy\n",
"import torch\n",
"\n",
"a = torch.tensor(2., requires_grad=True)\n",
"b = torch.tensor(3., requires_grad=True)\n",
"c = torch.tensor(5., requires_grad=True)\n",
"d = torch.tensor(7., requires_grad=True)\n",
"e = torch.tensor(11., requires_grad=True)\n",
"f = torch.tensor(13., requires_grad=True)\n",
"\n",
"x = torch.tensor(17., requires_grad=True)\n",
"\n",
"params = [a, b, c, d, e, f, x]\n",
"\n",
"print(params)"
]
},
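{
"cell_type": "markdown",
"metadata": {},
"source": [
"With $y = ax + b$, $z = cy + d$, $w = ez + f$, the chain rule gives the gradients the cells below should report:\n",
"\n",
"$$\\frac{\\partial w}{\\partial f} = 1, \\quad \\frac{\\partial w}{\\partial e} = z, \\quad \\frac{\\partial w}{\\partial d} = e, \\quad \\frac{\\partial w}{\\partial c} = ey, \\quad \\frac{\\partial w}{\\partial b} = ec, \\quad \\frac{\\partial w}{\\partial a} = ecx, \\quad \\frac{\\partial w}{\\partial x} = eca$$\n",
"\n",
"For $a, b, c, d, e, f, x = 2, 3, 5, 7, 11, 13, 17$ this means $y = 37$, $z = 192$, $w = 2125$, and gradients $1, 192, 11, 407, 55, 935, 110$ respectively."
]
},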
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"y=37.0, z=192.0, w=2125.0\nf.grad=1: tensor(1.)\ne.grad=z=192.0: tensor(192.)\nd.grad=e=11.0: tensor(11.)\nc.grad=ey=407.0: tensor(407.)\nb.grad=ey=55.0: tensor(55.)\na.grad=xey=935.0: tensor(935.)\nx.grad=aey=110.0: tensor(110.)\n"
]
}
],
"source": [
"for p in params:\n",
" if p.grad:\n",
" p.grad.data = torch.tensor(0.0) # zero_grad\n",
"\n",
"y = a * x + b\n",
"z = c * y + d\n",
"w = e * z + f\n",
"w.backward()\n",
"\n",
"print('y={}, z={}, w={}'.format(y, z, w))\n",
"print('f.grad=1:', f.grad)\n",
"print('e.grad=z={}:'.format(z.item()), e.grad)\n",
"print('d.grad=e={}:'.format(e.item()), d.grad)\n",
"print('c.grad=ey={}:'.format(e.item() * y.item()), c.grad)\n",
"print('b.grad=ey={}:'.format(e.item() * c.item()), b.grad)\n",
"print('a.grad=xey={}:'.format(x.item() * e.item() * c.item()), a.grad)\n",
"print('x.grad=aey={}:'.format(a.item() * e.item() * c.item()), x.grad)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"y=37.0, z=192.0, w=2125.0\nf.grad=1: tensor(1.)\ne.grad=z=192.0: tensor(192.)\nd.grad=e=11.0: tensor(11.)\nc.grad=ey=407.0: tensor(407.)\nb.grad=ey=55.0: tensor(0.)\na.grad=xey=935.0: tensor(0.)\nx.grad=aey=110.0: tensor(0.)\n"
]
}
],
"source": [
"for p in params:\n",
" if p.grad:\n",
" p.grad.data = torch.tensor(0.0) # zero_grad\n",
"\n",
"y = a * x + b\n",
"z = c * y.detach() + d\n",
"w = e * z + f\n",
"w.backward()\n",
"\n",
"print('y={}, z={}, w={}'.format(y, z, w))\n",
"print('f.grad=1:', f.grad)\n",
"print('e.grad=z={}:'.format(z.item()), e.grad)\n",
"print('d.grad=e={}:'.format(e.item()), d.grad)\n",
"print('c.grad=ey={}:'.format(e.item() * y.item()), c.grad)\n",
"print('b.grad=ey={}:'.format(e.item() * c.item()), b.grad)\n",
"print('a.grad=xey={}:'.format(x.item() * e.item() * c.item()), a.grad)\n",
"print('x.grad=aey={}:'.format(a.item() * e.item() * c.item()), x.grad)"
]
},
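{
"cell_type": "markdown",
"metadata": {},
"source": [
"`.detach()` returns a new tensor that shares the same data but is cut out of the computation graph, so `backward()` stops at the detached point: above, `c`, `d`, `e`, `f` still receive gradients, while `a`, `b`, `x`, which sit behind `y.detach()`, do not."
]
},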
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"y=37.0, z=192.0, w=2125.0\nf.grad=1: tensor(1.)\ne.grad=z=192.0: tensor(192.)\nd.grad=e=11.0: tensor(0.)\nc.grad=ey=407.0: tensor(0.)\nb.grad=ey=55.0: tensor(0.)\na.grad=xey=935.0: tensor(0.)\nx.grad=aey=110.0: tensor(0.)\n"
]
}
],
"source": [
"for p in params:\n",
" if p.grad:\n",
" p.grad.data = torch.tensor(0.0) # zero_grad\n",
"\n",
"y = a * x + b\n",
"z = c * y + d\n",
"w = e * z.detach() + f\n",
"w.backward()\n",
"\n",
"print('y={}, z={}, w={}'.format(y, z, w))\n",
"print('f.grad=1:', f.grad)\n",
"print('e.grad=z={}:'.format(z.item()), e.grad)\n",
"print('d.grad=e={}:'.format(e.item()), d.grad)\n",
"print('c.grad=ey={}:'.format(e.item() * y.item()), c.grad)\n",
"print('b.grad=ey={}:'.format(e.item() * c.item()), b.grad)\n",
"print('a.grad=xey={}:'.format(x.item() * e.item() * c.item()), a.grad)\n",
"print('x.grad=aey={}:'.format(a.item() * e.item() * c.item()), x.grad)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"y=37.0, z=192.0, w=2125.0\nf.grad=1: tensor(1.)\ne.grad=z=192.0: tensor(192.)\nd.grad=e=11.0: tensor(0.)\nc.grad=ey=407.0: tensor(0.)\nb.grad=ey=55.0: tensor(0.)\na.grad=xey=935.0: tensor(0.)\nx.grad=aey=110.0: tensor(0.)\n"
]
}
],
"source": [
"for p in params:\n",
" if p.grad:\n",
" p.grad.data = torch.tensor(0.0) # zero_grad\n",
"\n",
"y = a * x + b\n",
"with torch.no_grad():\n",
" z = c * y + d\n",
"w = e * z + f\n",
"w.backward()\n",
"\n",
"print('y={}, z={}, w={}'.format(y, z, w))\n",
"print('f.grad=1:', f.grad)\n",
"print('e.grad=z={}:'.format(z.item()), e.grad)\n",
"print('d.grad=e={}:'.format(e.item()), d.grad)\n",
"print('c.grad=ey={}:'.format(e.item() * y.item()), c.grad)\n",
"print('b.grad=ey={}:'.format(e.item() * c.item()), b.grad)\n",
"print('a.grad=xey={}:'.format(x.item() * e.item() * c.item()), a.grad)\n",
"print('x.grad=aey={}:'.format(a.item() * e.item() * c.item()), x.grad)"
]
},
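{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal check, not part of the original run: a tensor computed inside `torch.no_grad()` has `requires_grad=False` and no `grad_fn`, which is why the graph is cut at `z` above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"y = a * x + b\n",
"with torch.no_grad():\n",
"    z = c * y + d\n",
"\n",
"print(y.requires_grad, y.grad_fn)  # expected: True and an AddBackward0 node\n",
"print(z.requires_grad, z.grad_fn)  # expected: False and None"
]
},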
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"y=37.0, z=192.0, w=2125.0\n",
"f.grad=1: tensor(1.)\n",
"e.grad=z=192.0: tensor(192.)\n",
"d.grad=e=11.0: tensor(11.)\n",
"c.grad=ey=407.0: tensor(407.)\n",
"b.grad=ey=55.0: tensor(0.)\n",
"a.grad=xey=935.0: tensor(0.)\n",
"x.grad=aey=110.0: tensor(0.)\n"
]
}
],
"source": [
"for p in params:\n",
" if p.grad:\n",
" p.grad.data = torch.tensor(0.0) # zero_grad\n",
"\n",
"with torch.no_grad():\n",
" y = a * x + b\n",
"z = c * y + d\n",
"w = e * z + f\n",
"w.backward()\n",
"\n",
"print('y={}, z={}, w={}'.format(y, z, w))\n",
"print('f.grad=1:', f.grad)\n",
"print('e.grad=z={}:'.format(z.item()), e.grad)\n",
"print('d.grad=e={}:'.format(e.item()), d.grad)\n",
"print('c.grad=ey={}:'.format(e.item() * y.item()), c.grad)\n",
"print('b.grad=ey={}:'.format(e.item() * c.item()), b.grad)\n",
"print('a.grad=xey={}:'.format(x.item() * e.item() * c.item()), a.grad)\n",
"print('x.grad=aey={}:'.format(a.item() * e.item() * c.item()), x.grad)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"y=37.0, z=192.0, w=2125.0\nf.grad=1: tensor(1.)\ne.grad=z=192.0: tensor(192.)\nd.grad=e=11.0: tensor(0.)\nc.grad=ey=407.0: tensor(0.)\nb.grad=ey=55.0: tensor(55.)\na.grad=xey=935.0: tensor(935.)\nx.grad=aey=110.0: tensor(110.)\n"
]
}
],
"source": [
"for p in params:\n",
" if p.grad:\n",
" p.grad.data = torch.tensor(0.0) # zero_grad\n",
"\n",
"c.requires_grad = False\n",
"d.requires_grad = False\n",
"\n",
"y = a * x + b\n",
"z = c * y + d\n",
"w = e * z + f\n",
"w.backward()\n",
"\n",
"print('y={}, z={}, w={}'.format(y, z, w))\n",
"print('f.grad=1:', f.grad)\n",
"print('e.grad=z={}:'.format(z.item()), e.grad)\n",
"print('d.grad=e={}:'.format(e.item()), d.grad)\n",
"print('c.grad=ey={}:'.format(e.item() * y.item()), c.grad)\n",
"print('b.grad=ey={}:'.format(e.item() * c.item()), b.grad)\n",
"print('a.grad=xey={}:'.format(x.item() * e.item() * c.item()), a.grad)\n",
"print('x.grad=aey={}:'.format(a.item() * e.item() * c.item()), x.grad)\n",
"\n",
"c.requires_grad = True\n",
"d.requires_grad = True"
]
},
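{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a side note (a sketch, not in the original run): the in-place method `Tensor.requires_grad_()` toggles the same flag as the attribute assignments above, and is the usual way to freeze and unfreeze leaf tensors such as model parameters."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"c.requires_grad_(False)  # same effect as c.requires_grad = False\n",
"d.requires_grad_(False)\n",
"print(c.requires_grad, d.requires_grad)  # expected: False False\n",
"\n",
"c.requires_grad_(True)   # restore, as the cell above does\n",
"d.requires_grad_(True)"
]
},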
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"y=37.0, z=192.0, w=2125.0\nf.grad=1: tensor(0.)\ne.grad=z=192.0: tensor(0.)\nd.grad=e=11.0: tensor(11.)\nc.grad=ey=407.0: tensor(407.)\nb.grad=ey=55.0: tensor(55.)\na.grad=xey=935.0: tensor(935.)\nx.grad=aey=110.0: tensor(110.)\n"
]
}
],
"source": [
"for p in params:\n",
" if p.grad:\n",
" p.grad.data = torch.tensor(0.0) # zero_grad\n",
"\n",
"e.requires_grad = False\n",
"f.requires_grad = False\n",
"\n",
"y = a * x + b\n",
"z = c * y + d\n",
"w = e * z + f\n",
"w.backward()\n",
"\n",
"print('y={}, z={}, w={}'.format(y, z, w))\n",
"print('f.grad=1:', f.grad)\n",
"print('e.grad=z={}:'.format(z.item()), e.grad)\n",
"print('d.grad=e={}:'.format(e.item()), d.grad)\n",
"print('c.grad=ey={}:'.format(e.item() * y.item()), c.grad)\n",
"print('b.grad=ey={}:'.format(e.item() * c.item()), b.grad)\n",
"print('a.grad=xey={}:'.format(x.item() * e.item() * c.item()), a.grad)\n",
"print('x.grad=aey={}:'.format(a.item() * e.item() * c.item()), x.grad)\n",
"\n",
"e.requires_grad = True\n",
"f.requires_grad = True"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"output_type": "error",
"ename": "RuntimeError",
"evalue": "you can only change requires_grad flags of leaf variables. If you want to use a computed variable in a subgraph that doesn't require differentiation use var_no_grad = var.detach().",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m/tmp/ipykernel_8243/2375783274.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0ma\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mb\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0my\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrequires_grad\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 7\u001b[0m \u001b[0mz\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mc\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0my\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0md\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0mw\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0me\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mz\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mRuntimeError\u001b[0m: you can only change requires_grad flags of leaf variables. If you want to use a computed variable in a subgraph that doesn't require differentiation use var_no_grad = var.detach()."
]
}
],
"source": [
"for p in params:\n",
" if p.grad:\n",
" p.grad.data = torch.tensor(0.0) # zero_grad\n",
"\n",
"y = a * x + b\n",
"y.requires_grad = False\n",
"z = c * y + d\n",
"w = e * z + f\n",
"w.backward()\n",
"\n",
"print('y={}, z={}, w={}'.format(y, z, w))\n",
"print('f.grad=1:', f.grad)\n",
"print('e.grad=z={}:'.format(z.item()), e.grad)\n",
"print('d.grad=e={}:'.format(e.item()), d.grad)\n",
"print('c.grad=ey={}:'.format(e.item() * y.item()), c.grad)\n",
"print('b.grad=ey={}:'.format(e.item() * c.item()), b.grad)\n",
"print('a.grad=xey={}:'.format(x.item() * e.item() * c.item()), a.grad)\n",
"print('x.grad=aey={}:'.format(a.item() * e.item() * c.item()), x.grad)\n",
"\n"
]
}
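,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Following the remedy named in the error message (a sketch, not part of the original run): detach the computed tensor instead of flipping its `requires_grad` flag. This cuts the graph at `y`, just like the earlier `z = c * y.detach() + d` cell."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"y = a * x + b\n",
"y_no_grad = y.detach()  # var_no_grad = var.detach(), as the error message suggests\n",
"z = c * y_no_grad + d\n",
"w = e * z + f\n",
"w.backward()  # adds nothing to a.grad, b.grad, x.grad: the graph is cut at y"
]
}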
]
}