Skip to content

Instantly share code, notes, and snippets.

@suzusuzu
Last active October 7, 2019 16:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save suzusuzu/17471b63cb112988dc953b201ddd9fc5 to your computer and use it in GitHub Desktop.
Save suzusuzu/17471b63cb112988dc953b201ddd9fc5 to your computer and use it in GitHub Desktop.
pytorch_gradient_descent.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "pytorch_gradient_descent.ipynb",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/suzusuzu/17471b63cb112988dc953b201ddd9fc5/pytorch_gradient_descent.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "RdGgtliP4I8Y",
"colab_type": "code",
"outputId": "31b05d3e-8bc8-47b1-ee66-f38376bb25ae",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 85
}
},
"source": [
"! pip install -U -q torch\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"\n",
"# Objective (evaluation) function: the sphere function f(x) = sum_i x_i^2,\n",
"# whose global minimum is 0, attained at x = 0.\n",
"def _sphere(x):\n",
"    \"\"\"Return sum(x ** 2) over all elements of x as a 0-d tensor.\"\"\"\n",
"    return x.pow(2).sum()\n",
"\n",
"class Model(nn.Module):\n",
"    \"\"\"Wrap a single design-variable vector x as a trainable parameter.\n",
"\n",
"    Calling the module evaluates func(x), so an optimizer built over\n",
"    model.parameters() minimises func directly with respect to x.\n",
"\n",
"    Args:\n",
"        dim: length of the design-variable vector x.\n",
"        func: objective taking x; it should return a scalar tensor so the\n",
"            result can be used directly as a loss. Defaults to _sphere.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, dim=100, func=None):\n",
"        super().__init__()\n",
"        if func is None:\n",
"            func = _sphere\n",
"        self.func = func\n",
"        # Design variables, initialised uniformly at random in [0, 1).\n",
"        self.x = nn.Parameter(torch.rand(dim))\n",
"\n",
"    def forward(self):\n",
"        \"\"\"Return func(x) with autograd tracking, for use as the loss.\"\"\"\n",
"        return self.func(self.x)\n",
"\n",
"    def vars(self):\n",
"        \"\"\"Return the current design variables as a NumPy array.\n",
"\n",
"        .cpu() lets this also work after the model is moved to a GPU. On CPU\n",
"        the returned array shares memory with the parameter, so later\n",
"        optimizer steps are visible through it.\n",
"        \"\"\"\n",
"        return self.x.detach().cpu().numpy()\n",
"\n",
"    def objective(self):\n",
"        \"\"\"Evaluate func(x) without building an autograd graph; return it as NumPy.\"\"\"\n",
"        with torch.no_grad():\n",
"            return self.func(self.x).cpu().numpy()\n",
"\n",
"# Fix the RNG seed so the random initial point (and thus every printed value) is reproducible.\n",
"torch.manual_seed(0)\n",
"\n",
"# Optimise a 10,000,000-dimensional problem.\n",
"model = Model(dim=10000000)\n",
"# With the default sphere objective the gradient is 2x, so each SGD step\n",
"# scales x by (1 - 2 * lr) = 0.8 and the objective by 0.64.\n",
"optimizer = optim.SGD(model.parameters(), lr=0.1)\n",
"\n",
"N = 100  # number of gradient-descent steps\n",
"\n",
"print('初期の評価値', model.objective())\n",
"\n",
"print('optimization...')\n",
"for _ in range(N):\n",
"    optimizer.zero_grad()\n",
"    loss = model()  # objective value at the current x, before this step's update\n",
"    loss.backward()\n",
"    optimizer.step()\n",
"    # '\\r' returns to the start of the line so only the latest loss stays visible.\n",
"    print('\\rloss:', loss.item(), end='')\n",
"print()\n",
"\n",
"print('評価値', model.objective())"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": [
"初期の評価値 3333505.5\n",
"optimization...\n",
"loss: 2.1613312810482566e-13\n",
"評価値 1.383244e-13\n"
],
"name": "stdout"
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment