Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save gyu-don/f5cc025139312ccfd39e48400018118d to your computer and use it in GitHub Desktop.
Save gyu-don/f5cc025139312ccfd39e48400018118d to your computer and use it in GitHub Desktop.
MAKE PYTHON GREAT AGAIN
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import numpy as np\n",
"import torch\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# やりたいこと\n",
"計算グラフをぶった切って、オリジナルな微分を定義しよう(いや、そんなのやりたくねぇ。けど仕方ないんだ…)\n",
"\n",
"# 今回やること\n",
"計算グラフをぶった切る関数はいろいろ考えられると思いますが、複雑な関数だと合ってるのか合ってないのか分からないので、今回は単純な関数を文字列にしてevalで評価する関数を用意します。いくら計算が単純でも、さすがにPyTorchはそんなの面倒見きれないので、計算グラフがぶった切られることになります。\n",
"\n",
"また、そうすると自分で微分を定義しないといけないですが、今回は前進差分を使います。\n",
"$$\\frac{\\partial f(x, w)}{\\partial w} = \\frac{f(x, w + \\Delta w) - f(x, w)}{\\Delta w}$$\n",
"\n",
"# 利点\n",
"ないです。モデル作り直しましょう。本当にPyTorch使ってそれやらなきゃいけないのか考えましょう。\n",
"を使うことにします。"
]
},
{
"cell_type": "code",
"execution_count": 81,
"metadata": {},
"outputs": [],
"source": [
"# どちらも同じですが、f_strは、さすがにこんなの、PyTorchでは、微分を面倒見きれません。\n",
"# 入力は、x, wともPyTorchのTensor型を想定しています。\n",
"def f(x, w):\n",
" return 2 * x * w[0] + x**2 * w[1]\n",
"\n",
"def f_str(x, w):\n",
" return torch.tensor([eval(f'2 * {x_} * {w[0]} + {x_}**2 * {w[1]}') for x_ in x])"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([3.], grad_fn=<AddBackward0>)"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = torch.tensor([1.])\n",
"w = torch.tensor([1., 1.]).requires_grad_()\n",
"\n",
"f(x, w)"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([3.])"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = torch.tensor([1.])\n",
"w = torch.tensor([1., 1.]).requires_grad_()\n",
"\n",
"f_str(x, w)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`f_str`では、grad_fnがないことに注目"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
"x = torch.tensor([1.])\n",
"w = torch.tensor([1., 1.]).requires_grad_()\n",
"\n",
"y = f(x, w)\n",
"y.backward()"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([2., 1.])"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"w.grad"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`f`の計算結果は、自動微分される"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"ename": "RuntimeError",
"evalue": "element 0 of tensors does not require grad and does not have a grad_fn",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-40-9bc7a4e474c9>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mf_str\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mw\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0my\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m/usr/lib/python3.8/site-packages/torch/tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph)\u001b[0m\n\u001b[1;32m 196\u001b[0m \u001b[0mproducts\u001b[0m\u001b[0;34m.\u001b[0m \u001b[0mDefaults\u001b[0m \u001b[0mto\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 197\u001b[0m \"\"\"\n\u001b[0;32m--> 198\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 199\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 200\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/lib/python3.8/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables)\u001b[0m\n\u001b[1;32m 96\u001b[0m \u001b[0mretain_graph\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 97\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 98\u001b[0;31m Variable._execution_engine.run_backward(\n\u001b[0m\u001b[1;32m 99\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 100\u001b[0m allow_unreachable=True) # allow_unreachable flag\n",
"\u001b[0;31mRuntimeError\u001b[0m: element 0 of tensors does not require grad and does not have a grad_fn"
]
}
],
"source": [
"x = torch.tensor([1.])\n",
"w = torch.tensor([1., 1.]).requires_grad_()\n",
"\n",
"y = f_str(x, w)\n",
"y.backward() # RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"None\n"
]
}
],
"source": [
"# backwardが計算できなかったので、w.gradはNoneのまま\n",
"print(w.grad)"
]
},
{
"cell_type": "code",
"execution_count": 82,
"metadata": {},
"outputs": [],
"source": [
"# 学習用に、データセットを作ります。\n",
"\n",
"actual_w = 1.2, -3.4\n",
"\n",
"xs = np.random.rand(200).astype(np.float32)\n",
"ys = np.array([f(x, actual_w) for x in xs], dtype=np.float32)\n",
"train_d = torch.utils.data.TensorDataset(torch.from_numpy(xs), torch.from_numpy(ys))\n",
"train_loader = torch.utils.data.DataLoader(train_d, batch_size=10)\n",
"\n",
"v_xs = np.random.rand(10).astype(np.float32)\n",
"v_ys = np.array([f(x, actual_w) for x in v_xs], dtype=np.float32)\n",
"valid_d = torch.utils.data.TensorDataset(torch.from_numpy(v_xs), torch.from_numpy(v_ys))\n",
"valid_loader = torch.utils.data.DataLoader(valid_d, batch_size=1)"
]
},
{
"cell_type": "code",
"execution_count": 83,
"metadata": {},
"outputs": [],
"source": [
"# まずは、普通のやつのモデルを作ってみる\n",
"class Model(torch.nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.weight = torch.nn.parameter.Parameter(torch.tensor([0., 0.]))\n",
" \n",
" def forward(self, x):\n",
" return f(x, self.weight)"
]
},
{
"cell_type": "code",
"execution_count": 104,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 tensor(0.1590, grad_fn=<MseLossBackward>) Parameter containing:\n",
"tensor([ 0.0512, -0.5521], requires_grad=True)\n",
"1 tensor(0.1147, grad_fn=<MseLossBackward>) Parameter containing:\n",
"tensor([ 0.2418, -1.0564], requires_grad=True)\n",
"2 tensor(0.0806, grad_fn=<MseLossBackward>) Parameter containing:\n",
"tensor([ 0.4203, -1.5229], requires_grad=True)\n",
"3 tensor(0.0569, grad_fn=<MseLossBackward>) Parameter containing:\n",
"tensor([ 0.5757, -1.9350], requires_grad=True)\n",
"4 tensor(0.0397, grad_fn=<MseLossBackward>) Parameter containing:\n",
"tensor([ 0.7089, -2.2835], requires_grad=True)\n",
"5 tensor(0.0267, grad_fn=<MseLossBackward>) Parameter containing:\n",
"tensor([ 0.8218, -2.5669], requires_grad=True)\n",
"6 tensor(0.0169, grad_fn=<MseLossBackward>) Parameter containing:\n",
"tensor([ 0.9162, -2.7894], requires_grad=True)\n",
"7 tensor(0.0097, grad_fn=<MseLossBackward>) Parameter containing:\n",
"tensor([ 0.9938, -2.9584], requires_grad=True)\n",
"8 tensor(0.0049, grad_fn=<MseLossBackward>) Parameter containing:\n",
"tensor([ 1.0559, -3.0836], requires_grad=True)\n",
"9 tensor(0.0021, grad_fn=<MseLossBackward>) Parameter containing:\n",
"tensor([ 1.1033, -3.1755], requires_grad=True)\n",
"10 tensor(0.0007, grad_fn=<MseLossBackward>) Parameter containing:\n",
"tensor([ 1.1373, -3.2434], requires_grad=True)\n",
"11 tensor(0.0002, grad_fn=<MseLossBackward>) Parameter containing:\n",
"tensor([ 1.1600, -3.2940], requires_grad=True)\n",
"12 tensor(8.7163e-05, grad_fn=<MseLossBackward>) Parameter containing:\n",
"tensor([ 1.1746, -3.3311], requires_grad=True)\n",
"13 tensor(3.2799e-05, grad_fn=<MseLossBackward>) Parameter containing:\n",
"tensor([ 1.1842, -3.3569], requires_grad=True)\n",
"14 tensor(1.1896e-05, grad_fn=<MseLossBackward>) Parameter containing:\n",
"tensor([ 1.1905, -3.3739], requires_grad=True)\n",
"15 tensor(3.8930e-06, grad_fn=<MseLossBackward>) Parameter containing:\n",
"tensor([ 1.1945, -3.3847], requires_grad=True)\n",
"16 tensor(1.1996e-06, grad_fn=<MseLossBackward>) Parameter containing:\n",
"tensor([ 1.1969, -3.3913], requires_grad=True)\n",
"17 tensor(3.5391e-07, grad_fn=<MseLossBackward>) Parameter containing:\n",
"tensor([ 1.1983, -3.3952], requires_grad=True)\n",
"18 tensor(9.6183e-08, grad_fn=<MseLossBackward>) Parameter containing:\n",
"tensor([ 1.1991, -3.3975], requires_grad=True)\n",
"19 tensor(2.4172e-08, grad_fn=<MseLossBackward>) Parameter containing:\n",
"tensor([ 1.1996, -3.3987], requires_grad=True)\n"
]
}
],
"source": [
"model = Model()\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=0.1)\n",
"criterion = torch.nn.MSELoss()\n",
"loss_hist = []\n",
"model.train()\n",
"for epoch in range(20):\n",
" for i, (xs, l) in enumerate(train_loader):\n",
" out = model(xs)\n",
" loss = criterion(out, l)\n",
" loss_hist.append(loss)\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
" print(epoch, loss, model.weight)"
]
},
{
"cell_type": "code",
"execution_count": 105,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f98b076a160>]"
]
},
"execution_count": 105,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.plot(loss_hist)"
]
},
{
"cell_type": "code",
"execution_count": 92,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(1.9578e-10, grad_fn=<MseLossBackward>)\n",
"tensor(7.9192e-09, grad_fn=<MseLossBackward>)\n",
"tensor(1.8404e-08, grad_fn=<MseLossBackward>)\n",
"tensor(1.5630e-08, grad_fn=<MseLossBackward>)\n",
"tensor(3.5687e-09, grad_fn=<MseLossBackward>)\n",
"tensor(1.4232e-08, grad_fn=<MseLossBackward>)\n",
"tensor(2.1366e-09, grad_fn=<MseLossBackward>)\n",
"tensor(1.2179e-08, grad_fn=<MseLossBackward>)\n",
"tensor(2.3878e-08, grad_fn=<MseLossBackward>)\n",
"tensor(1.9628e-08, grad_fn=<MseLossBackward>)\n"
]
}
],
"source": [
"for x, y in valid_loader:\n",
" pred = model(x)\n",
" print(criterion(pred, y))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"だいぶいい感じ。ノイズも何もないから、そりゃそう。"
]
},
{
"cell_type": "code",
"execution_count": 120,
"metadata": {},
"outputs": [],
"source": [
"# f_strに変えるだけではダメなことを見ていく\n",
"class Model2(torch.nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.weight = torch.nn.parameter.Parameter(torch.tensor([0., 0.]))\n",
" \n",
" def forward(self, x):\n",
" return f_str(x, self.weight)"
]
},
{
"cell_type": "code",
"execution_count": 121,
"metadata": {},
"outputs": [
{
"ename": "RuntimeError",
"evalue": "element 0 of tensors does not require grad and does not have a grad_fn",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-121-194d051407ff>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0mloss_hist2\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloss\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 13\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mepoch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel2\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/lib/python3.8/site-packages/torch/tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph)\u001b[0m\n\u001b[1;32m 196\u001b[0m \u001b[0mproducts\u001b[0m\u001b[0;34m.\u001b[0m \u001b[0mDefaults\u001b[0m \u001b[0mto\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 197\u001b[0m \"\"\"\n\u001b[0;32m--> 198\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 199\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 200\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/lib/python3.8/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables)\u001b[0m\n\u001b[1;32m 96\u001b[0m \u001b[0mretain_graph\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 97\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 98\u001b[0;31m Variable._execution_engine.run_backward(\n\u001b[0m\u001b[1;32m 99\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 100\u001b[0m allow_unreachable=True) # allow_unreachable flag\n",
"\u001b[0;31mRuntimeError\u001b[0m: element 0 of tensors does not require grad and does not have a grad_fn"
]
}
],
"source": [
"model2 = Model2()\n",
"optimizer = torch.optim.Adam(model2.parameters(), lr=0.1)\n",
"criterion = torch.nn.MSELoss()\n",
"loss_hist2 = []\n",
"model2.train()\n",
"for epoch in range(20):\n",
" for i, (xs, l) in enumerate(train_loader):\n",
" out = model2(xs)\n",
" loss = criterion(out, l)\n",
" loss_hist2.append(loss)\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
" print(epoch, loss, model2.weight)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"続いて、微分を定義していく。\n",
"Functionを使うこと、staticmethodとしてforward, backwardを定義することはドキュメント通り。ドキュメントになかった部分を補足していく。\n",
"\n",
"- `ctx.save_for_backward`でテンソルを保存できるが、この方法では`torch.Tensor`以外は保存できない。けれど、`ctx.なんちゃら = ...`の形で保存することができ、これは`backward`で使うことが出来る\n",
" - [Pytorch内部でも使われているテクニック](https://github.com/pytorch/pytorch/blob/master/torch/autograd/_functions/tensor.py#L11)なので、おそらく、使ってもいいんじゃないか\n",
"- `backward`で返す値は、`forward`の引数に対応している。`forward`の引数からctxを除いたものの微分結果を返していく\n",
" - 微分が必要ないもの(テンソルじゃないものや、`required_grad=True`じゃないテンソル)に対応している箇所はNoneを返せばいい\n",
"- 入力がテンソル${\\bf w} = [w_0, w_1, ..., w_{n-1}]$の場合、返す値は$$[\\sum_i\\mathrm{grad\\_output}_i\\frac{\\partial f(x_i, {\\bf w})}{\\partial w_0}, \\sum_i \\mathrm{grad\\_output}_i\\frac{\\partial f(x_i, {\\bf w})}{\\partial w_1}, ... \\sum_i\\mathrm{grad\\_output}_i\\frac{\\partial f(x_i, {\\bf w})}{\\partial w_{n-1}}]$$となる。ただし、$\\sum_i$は、入力xがミニバッチ$[x_0, x_1, ...]$で来た場合に、各々の結果を足し合わせることを言っている。`grad_output`の次元はミニバッチの大きさに対応しているので、このように結果に掛け合わせる。"
]
},
{
"cell_type": "code",
"execution_count": 93,
"metadata": {},
"outputs": [],
"source": [
"# Functionを使う。staticmethodを使う。そのへんは、ドキュメント通り\n",
"class GeneralFunctionWithForwardDifference(torch.autograd.Function):\n",
" @staticmethod\n",
" def forward(ctx, f, xs, weight):\n",
" ys = f(xs, weight)\n",
" ctx.save_for_backward(xs, ys, weight)\n",
" ctx.f = f # 実はctxにも何かを保存できて、backwardで使える\n",
" return ys\n",
" \n",
" @staticmethod\n",
" def backward(ctx, grad_output):\n",
" xs, ys, weight = ctx.saved_tensors\n",
" f = ctx.f\n",
" dw = 0.001\n",
" diff = []\n",
" weight = weight.detach() # weightに余計な計算履歴を残さないために、detachする。\n",
" for i in range(len(weight)):\n",
" weight[i] += dw\n",
" diff.append(torch.sum(grad_output * (f(xs, weight) - ys)))\n",
" weight[i] -= dw\n",
" diff = torch.tensor(diff) / dw\n",
" return None, None, diff"
]
},
{
"cell_type": "code",
"execution_count": 128,
"metadata": {},
"outputs": [],
"source": [
"class Model2(torch.nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.weight = torch.nn.parameter.Parameter(torch.tensor([0., 0.]))\n",
" \n",
" def forward(self, x):\n",
" # 書くのが若干めんどくさい。\n",
" return GeneralFunctionWithForwardDifference.apply(f_str, x, self.weight)"
]
},
{
"cell_type": "code",
"execution_count": 129,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 tensor(0.1590, grad_fn=<MseLossBackward>) Parameter containing:\n",
"tensor([ 0.0512, -0.5521], requires_grad=True)\n",
"1 tensor(0.1147, grad_fn=<MseLossBackward>) Parameter containing:\n",
"tensor([ 0.2418, -1.0564], requires_grad=True)\n",
"2 tensor(0.0806, grad_fn=<MseLossBackward>) Parameter containing:\n",
"tensor([ 0.4203, -1.5229], requires_grad=True)\n",
"3 tensor(0.0569, grad_fn=<MseLossBackward>) Parameter containing:\n",
"tensor([ 0.5757, -1.9350], requires_grad=True)\n",
"4 tensor(0.0397, grad_fn=<MseLossBackward>) Parameter containing:\n",
"tensor([ 0.7089, -2.2835], requires_grad=True)\n",
"5 tensor(0.0267, grad_fn=<MseLossBackward>) Parameter containing:\n",
"tensor([ 0.8217, -2.5669], requires_grad=True)\n",
"6 tensor(0.0169, grad_fn=<MseLossBackward>) Parameter containing:\n",
"tensor([ 0.9161, -2.7894], requires_grad=True)\n",
"7 tensor(0.0097, grad_fn=<MseLossBackward>) Parameter containing:\n",
"tensor([ 0.9938, -2.9584], requires_grad=True)\n",
"8 tensor(0.0049, grad_fn=<MseLossBackward>) Parameter containing:\n",
"tensor([ 1.0559, -3.0836], requires_grad=True)\n",
"9 tensor(0.0021, grad_fn=<MseLossBackward>) Parameter containing:\n",
"tensor([ 1.1033, -3.1755], requires_grad=True)\n",
"10 tensor(0.0007, grad_fn=<MseLossBackward>) Parameter containing:\n",
"tensor([ 1.1372, -3.2434], requires_grad=True)\n",
"11 tensor(0.0002, grad_fn=<MseLossBackward>) Parameter containing:\n",
"tensor([ 1.1600, -3.2940], requires_grad=True)\n",
"12 tensor(8.7188e-05, grad_fn=<MseLossBackward>) Parameter containing:\n",
"tensor([ 1.1746, -3.3310], requires_grad=True)\n",
"13 tensor(3.2809e-05, grad_fn=<MseLossBackward>) Parameter containing:\n",
"tensor([ 1.1842, -3.3569], requires_grad=True)\n",
"14 tensor(1.1901e-05, grad_fn=<MseLossBackward>) Parameter containing:\n",
"tensor([ 1.1905, -3.3739], requires_grad=True)\n",
"15 tensor(3.8952e-06, grad_fn=<MseLossBackward>) Parameter containing:\n",
"tensor([ 1.1945, -3.3847], requires_grad=True)\n",
"16 tensor(1.1996e-06, grad_fn=<MseLossBackward>) Parameter containing:\n",
"tensor([ 1.1969, -3.3913], requires_grad=True)\n",
"17 tensor(3.5389e-07, grad_fn=<MseLossBackward>) Parameter containing:\n",
"tensor([ 1.1983, -3.3952], requires_grad=True)\n",
"18 tensor(9.6174e-08, grad_fn=<MseLossBackward>) Parameter containing:\n",
"tensor([ 1.1991, -3.3975], requires_grad=True)\n",
"19 tensor(2.4199e-08, grad_fn=<MseLossBackward>) Parameter containing:\n",
"tensor([ 1.1996, -3.3987], requires_grad=True)\n"
]
}
],
"source": [
"model2 = Model2()\n",
"optimizer = torch.optim.Adam(model2.parameters(), lr=0.1)\n",
"criterion = torch.nn.MSELoss()\n",
"loss_hist2 = []\n",
"model2.train()\n",
"for epoch in range(20):\n",
" for i, (xs, l) in enumerate(train_loader):\n",
" out = model2(xs)\n",
" loss = criterion(out, l)\n",
" loss_hist2.append(loss)\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
" print(epoch, loss, model2.weight)"
]
},
{
"cell_type": "code",
"execution_count": 130,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f98950c3400>]"
]
},
"execution_count": 130,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.plot(loss_hist2)"
]
},
{
"cell_type": "code",
"execution_count": 131,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(1.9318e-10, grad_fn=<MseLossBackward>)\n",
"tensor(7.9086e-09, grad_fn=<MseLossBackward>)\n",
"tensor(1.8396e-08, grad_fn=<MseLossBackward>)\n",
"tensor(1.5653e-08, grad_fn=<MseLossBackward>)\n",
"tensor(3.5723e-09, grad_fn=<MseLossBackward>)\n",
"tensor(1.4225e-08, grad_fn=<MseLossBackward>)\n",
"tensor(2.1283e-09, grad_fn=<MseLossBackward>)\n",
"tensor(1.2192e-08, grad_fn=<MseLossBackward>)\n",
"tensor(2.3896e-08, grad_fn=<MseLossBackward>)\n",
"tensor(1.9645e-08, grad_fn=<MseLossBackward>)\n"
]
}
],
"source": [
"for x, y in valid_loader:\n",
" pred = model2(x)\n",
" print(criterion(pred, y))"
]
},
{
"cell_type": "code",
"execution_count": 132,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x7f989501af10>"
]
},
"execution_count": 132,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# lossを重ねて書いてみる\n",
"plt.plot(loss_hist, label=\"f\")\n",
"plt.plot(loss_hist2, label=\"f_str\")\n",
"plt.legend()"
]
},
{
"cell_type": "code",
"execution_count": 133,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([9.3132e-08], grad_fn=<SubBackward0>)\n",
"tensor([5.9605e-08], grad_fn=<SubBackward0>)\n",
"tensor([2.9802e-08], grad_fn=<SubBackward0>)\n",
"tensor([-8.9407e-08], grad_fn=<SubBackward0>)\n",
"tensor([-2.9802e-08], grad_fn=<SubBackward0>)\n",
"tensor([2.9802e-08], grad_fn=<SubBackward0>)\n",
"tensor([8.9407e-08], grad_fn=<SubBackward0>)\n",
"tensor([-5.9605e-08], grad_fn=<SubBackward0>)\n",
"tensor([-5.9605e-08], grad_fn=<SubBackward0>)\n",
"tensor([-5.9605e-08], grad_fn=<SubBackward0>)\n"
]
}
],
"source": [
"# predの差は? → どれもほぼゼロなので、どっちのモデルもほぼ同じと言える\n",
"for x, y in valid_loader:\n",
" diff = model2(x) - model(x)\n",
" print(diff)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment