@douglasrizzo
Created July 2, 2020 07:05
wandb_dict_forked_net.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "wandb_dict_forked_net.ipynb",
"provenance": [],
"authorship_tag": "ABX9TyNuid6KzwgJ8c3G6u45upsB",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/douglasrizzo/81b9b2a190cf5ad3125e929df919e98d/wandb_dict_forked_net.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "fA_qOdGNTHl1",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"outputId": "f8786ab7-6b8a-4e7d-bc31-f3dff4c564b3"
},
"source": [
"!pip install wandb"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Collecting wandb\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/00/8e/d43984196a0fa8ef961ae3dce91ada52ae7747fbf39d41f5743c27152d97/wandb-0.9.2-py2.py3-none-any.whl (1.4MB)\n",
"\u001b[K |████████████████████████████████| 1.4MB 2.8MB/s \n",
"\u001b[?25hRequirement already satisfied: six>=1.10.0 in /usr/local/lib/python3.6/dist-packages (from wandb) (1.12.0)\n",
"Collecting watchdog>=0.8.3\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/0e/06/121302598a4fc01aca942d937f4a2c33430b7181137b35758913a8db10ad/watchdog-0.10.3.tar.gz (94kB)\n",
"\u001b[K |████████████████████████████████| 102kB 9.2MB/s \n",
"\u001b[?25hRequirement already satisfied: Click>=7.0 in /usr/local/lib/python3.6/dist-packages (from wandb) (7.1.2)\n",
"Collecting sentry-sdk>=0.4.0\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/2f/6b/939519d77c95a9b2c85b771e9dccbf9e69cb90016c7cd63887c26400dd7a/sentry_sdk-0.15.1-py2.py3-none-any.whl (105kB)\n",
"\u001b[K |████████████████████████████████| 112kB 15.4MB/s \n",
"\u001b[?25hCollecting gql==0.2.0\n",
" Downloading https://files.pythonhosted.org/packages/c4/6f/cf9a3056045518f06184e804bae89390eb706168349daa9dff8ac609962a/gql-0.2.0.tar.gz\n",
"Requirement already satisfied: python-dateutil>=2.6.1 in /usr/local/lib/python3.6/dist-packages (from wandb) (2.8.1)\n",
"Requirement already satisfied: PyYAML>=3.10 in /usr/local/lib/python3.6/dist-packages (from wandb) (3.13)\n",
"Requirement already satisfied: nvidia-ml-py3>=7.352.0 in /usr/local/lib/python3.6/dist-packages (from wandb) (7.352.0)\n",
"Requirement already satisfied: psutil>=5.0.0 in /usr/local/lib/python3.6/dist-packages (from wandb) (5.4.8)\n",
"Collecting subprocess32>=3.5.3\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/32/c8/564be4d12629b912ea431f1a50eb8b3b9d00f1a0b1ceff17f266be190007/subprocess32-3.5.4.tar.gz (97kB)\n",
"\u001b[K |████████████████████████████████| 102kB 7.4MB/s \n",
"\u001b[?25hCollecting GitPython>=1.0.0\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/8c/f9/c315aa88e51fabdc08e91b333cfefb255aff04a2ee96d632c32cb19180c9/GitPython-3.1.3-py3-none-any.whl (451kB)\n",
"\u001b[K |████████████████████████████████| 460kB 16.4MB/s \n",
"\u001b[?25hCollecting docker-pycreds>=0.4.0\n",
" Downloading https://files.pythonhosted.org/packages/f5/e8/f6bd1eee09314e7e6dee49cbe2c5e22314ccdb38db16c9fc72d2fa80d054/docker_pycreds-0.4.0-py2.py3-none-any.whl\n",
"Collecting configparser>=3.8.1\n",
" Downloading https://files.pythonhosted.org/packages/4b/6b/01baa293090240cf0562cc5eccb69c6f5006282127f2b846fad011305c79/configparser-5.0.0-py3-none-any.whl\n",
"Collecting shortuuid>=0.5.0\n",
" Downloading https://files.pythonhosted.org/packages/25/a6/2ecc1daa6a304e7f1b216f0896b26156b78e7c38e1211e9b798b4716c53d/shortuuid-1.0.1-py3-none-any.whl\n",
"Requirement already satisfied: requests>=2.0.0 in /usr/local/lib/python3.6/dist-packages (from wandb) (2.23.0)\n",
"Collecting pathtools>=0.1.1\n",
" Downloading https://files.pythonhosted.org/packages/e7/7f/470d6fcdf23f9f3518f6b0b76be9df16dcc8630ad409947f8be2eb0ed13a/pathtools-0.1.2.tar.gz\n",
"Requirement already satisfied: certifi in /usr/local/lib/python3.6/dist-packages (from sentry-sdk>=0.4.0->wandb) (2020.6.20)\n",
"Requirement already satisfied: urllib3>=1.10.0 in /usr/local/lib/python3.6/dist-packages (from sentry-sdk>=0.4.0->wandb) (1.24.3)\n",
"Collecting graphql-core<2,>=0.5.0\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/b0/89/00ad5e07524d8c523b14d70c685e0299a8b0de6d0727e368c41b89b7ed0b/graphql-core-1.1.tar.gz (70kB)\n",
"\u001b[K |████████████████████████████████| 71kB 8.0MB/s \n",
"\u001b[?25hRequirement already satisfied: promise<3,>=2.0 in /usr/local/lib/python3.6/dist-packages (from gql==0.2.0->wandb) (2.3)\n",
"Collecting gitdb<5,>=4.0.1\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/48/11/d1800bca0a3bae820b84b7d813ad1eff15a48a64caea9c823fc8c1b119e8/gitdb-4.0.5-py3-none-any.whl (63kB)\n",
"\u001b[K |████████████████████████████████| 71kB 8.2MB/s \n",
"\u001b[?25hRequirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests>=2.0.0->wandb) (3.0.4)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests>=2.0.0->wandb) (2.9)\n",
"Collecting smmap<4,>=3.0.1\n",
" Downloading https://files.pythonhosted.org/packages/b0/9a/4d409a6234eb940e6a78dfdfc66156e7522262f5f2fecca07dc55915952d/smmap-3.0.4-py2.py3-none-any.whl\n",
"Building wheels for collected packages: watchdog, gql, subprocess32, pathtools, graphql-core\n",
" Building wheel for watchdog (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for watchdog: filename=watchdog-0.10.3-cp36-none-any.whl size=73870 sha256=68178c701148e2175ac6faf8c5909a881e107661898e6d457de37319b66841ff\n",
" Stored in directory: /root/.cache/pip/wheels/a8/1d/38/2c19bb311f67cc7b4d07a2ec5ea36ab1a0a0ea50db994a5bc7\n",
" Building wheel for gql (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for gql: filename=gql-0.2.0-cp36-none-any.whl size=7630 sha256=d71a26c76744d5bdffdcdebe8f71449eb39bcb3932a2928fb2a86d915fed0f00\n",
" Stored in directory: /root/.cache/pip/wheels/ce/0e/7b/58a8a5268655b3ad74feef5aa97946f0addafb3cbb6bd2da23\n",
" Building wheel for subprocess32 (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for subprocess32: filename=subprocess32-3.5.4-cp36-none-any.whl size=6489 sha256=29029465143372e3cb022d6356009abb6a9b65cf4d567475867070fa9f22b931\n",
" Stored in directory: /root/.cache/pip/wheels/68/39/1a/5e402bdfdf004af1786c8b853fd92f8c4a04f22aad179654d1\n",
" Building wheel for pathtools (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for pathtools: filename=pathtools-0.1.2-cp36-none-any.whl size=8784 sha256=55f43fa983289fd5e10c32146ac088a36e2020da8c04609919b8604ef443747b\n",
" Stored in directory: /root/.cache/pip/wheels/0b/04/79/c3b0c3a0266a3cb4376da31e5bfe8bba0c489246968a68e843\n",
" Building wheel for graphql-core (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for graphql-core: filename=graphql_core-1.1-cp36-none-any.whl size=104650 sha256=b8b4027b6eeabf627fd99fac272ce464f1cf80607423854d5541f9509582a3d8\n",
" Stored in directory: /root/.cache/pip/wheels/45/99/d7/c424029bb0fe910c63b68dbf2aa20d3283d023042521bcd7d5\n",
"Successfully built watchdog gql subprocess32 pathtools graphql-core\n",
"Installing collected packages: pathtools, watchdog, sentry-sdk, graphql-core, gql, subprocess32, smmap, gitdb, GitPython, docker-pycreds, configparser, shortuuid, wandb\n",
"Successfully installed GitPython-3.1.3 configparser-5.0.0 docker-pycreds-0.4.0 gitdb-4.0.5 gql-0.2.0 graphql-core-1.1 pathtools-0.1.2 sentry-sdk-0.15.1 shortuuid-1.0.1 smmap-3.0.4 subprocess32-3.5.4 wandb-0.9.2 watchdog-0.10.3\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "2tFg7xkyR6CN",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 162
},
"outputId": "281c4e2e-d346-4917-860a-b61ea1aa5a0a"
},
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from torch import optim\n",
"\n",
"import wandb\n",
"\n",
"\n",
"def num_flat_features(x):\n",
"    size = x.size()[1:]  # all dimensions except the batch dimension\n",
"    num_features = 1\n",
"    for s in size:\n",
"        num_features *= s\n",
"    return num_features\n",
"\n",
"\n",
"class Net1(nn.Module):\n",
"    # Single output layer whose 10 outputs are split into two 5-dimensional heads\n",
"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        # 1 input image channel, 6 output channels, 3x3 square convolution kernel\n",
"        self.conv1 = nn.Conv2d(1, 6, 3)\n",
"        self.conv2 = nn.Conv2d(6, 16, 3)\n",
"        # an affine operation: y = Wx + b\n",
"        self.fc1 = nn.Linear(16 * 6 * 6, 120)  # 6*6 is the spatial size left after two conv+pool stages on a 32x32 input\n",
"        self.fc2 = nn.Linear(120, 84)\n",
"        self.fc3 = nn.Linear(84, 10)\n",
"\n",
"    def forward(self, x):\n",
"        # Max pooling over a (2, 2) window\n",
"        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))\n",
"        # If the window is square, a single number is enough\n",
"        x = F.max_pool2d(F.relu(self.conv2(x)), 2)\n",
"        x = x.view(-1, num_flat_features(x))\n",
"        x = F.relu(self.fc1(x))\n",
"        x = F.relu(self.fc2(x))\n",
"        x = self.fc3(x)\n",
"\n",
"        # Keep the batch dimension so each head has shape (batch, 5)\n",
"        x_dict = {'a': x[:, 5:], 'b': x[:, :5]}\n",
"        return x_dict\n",
"\n",
"\n",
"class Net2(nn.Module):\n",
"    # Shared trunk that forks into two separate 5-unit output layers\n",
"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        # 1 input image channel, 6 output channels, 3x3 square convolution kernel\n",
"        self.conv1 = nn.Conv2d(1, 6, 3)\n",
"        self.conv2 = nn.Conv2d(6, 16, 3)\n",
"        # an affine operation: y = Wx + b\n",
"        self.fc1 = nn.Linear(16 * 6 * 6, 120)  # 6*6 is the spatial size left after two conv+pool stages on a 32x32 input\n",
"        self.fc2 = nn.Linear(120, 84)\n",
"        self.fc3_1 = nn.Linear(84, 5)  # head 'a'\n",
"        self.fc3_2 = nn.Linear(84, 5)  # head 'b'\n",
"\n",
"    def forward(self, x):\n",
"        # Max pooling over a (2, 2) window\n",
"        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))\n",
"        # If the window is square, a single number is enough\n",
"        x = F.max_pool2d(F.relu(self.conv2(x)), 2)\n",
"        x = x.view(-1, num_flat_features(x))\n",
"        x = F.relu(self.fc1(x))\n",
"        x = F.relu(self.fc2(x))\n",
"\n",
"        # Fork: each head sees the same shared features\n",
"        return {'a': self.fc3_1(x), 'b': self.fc3_2(x)}\n",
"\n",
"\n",
"if __name__ == \"__main__\":\n",
"    wandb.init(project='my_test', group='conv_nets')\n",
"\n",
"    # net = Net1()\n",
"    net = Net2()\n",
"    wandb.watch(net)\n",
"\n",
"    optimizer = optim.SGD(net.parameters(), lr=0.01)\n",
"    criterion = nn.MSELoss()\n",
"\n",
"    for _ in range(100):\n",
"        in_feats = torch.randn(1, 1, 32, 32)\n",
"        # Targets match the (1, 5) shape of each head's output, avoiding a broadcasting warning\n",
"        target = {'a': torch.randn(1, 5), 'b': torch.randn(1, 5)}\n",
"\n",
"        optimizer.zero_grad()  # zero the gradient buffers\n",
"        output = net(in_feats)\n",
"\n",
"        # Sum the per-head losses so that both heads receive gradients\n",
"        loss = sum(criterion(output[key], target[key]) for key in output)\n",
"\n",
"        wandb.log({'loss': loss.item()})\n",
"        loss.backward()\n",
"        optimizer.step()  # update the weights\n"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"application/javascript": [
"\n",
" window._wandbApiKey = new Promise((resolve, reject) => {\n",
" function loadScript(url) {\n",
" return new Promise(function(resolve, reject) {\n",
" let newScript = document.createElement(\"script\");\n",
" newScript.onerror = reject;\n",
" newScript.onload = resolve;\n",
" document.body.appendChild(newScript);\n",
" newScript.src = url;\n",
" });\n",
" }\n",
" loadScript(\"https://cdn.jsdelivr.net/npm/postmate/build/postmate.min.js\").then(() => {\n",
" const iframe = document.createElement('iframe')\n",
" iframe.style.cssText = \"width:0;height:0;border:none\"\n",
" document.body.appendChild(iframe)\n",
" const handshake = new Postmate({\n",
" container: iframe,\n",
" url: 'https://app.wandb.ai/authorize'\n",
" });\n",
" const timeout = setTimeout(() => reject(\"Couldn't auto authenticate\"), 5000)\n",
" handshake.then(function(child) {\n",
" child.on('authorize', data => {\n",
" clearTimeout(timeout)\n",
" resolve(data)\n",
" });\n",
" });\n",
" })\n",
" });\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc\n"
],
"name": "stderr"
},
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
" Logging results to <a href=\"https://wandb.com\" target=\"_blank\">Weights & Biases</a> <a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">(Documentation)</a>.<br/>\n",
" Project page: <a href=\"https://app.wandb.ai/tetamusha/my_test\" target=\"_blank\">https://app.wandb.ai/tetamusha/my_test</a><br/>\n",
" Run page: <a href=\"https://app.wandb.ai/tetamusha/my_test/runs/336leocg\" target=\"_blank\">https://app.wandb.ai/tetamusha/my_test/runs/336leocg</a><br/>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"/usr/local/lib/python3.6/dist-packages/torch/nn/modules/loss.py:432: UserWarning: Using a target size (torch.Size([5])) that is different to the input size (torch.Size([1, 5])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.\n",
" return F.mse_loss(input, target, reduction=self.reduction)\n"
],
"name": "stderr"
}
]
}
]
}
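The `16 * 6 * 6` input size of `fc1` follows from the usual output-size formula for unpadded convolution and pooling. The pure-Python sketch that follows traces one spatial side of a 32x32 input through the two conv+pool stages. It assumes stride 1 and no padding for `nn.Conv2d(..., 3)` (the PyTorch defaults), that `F.max_pool2d` uses a stride equal to its kernel size (also the default), and it ignores dilation. The helper names `conv2d_out` and `maxpool2d_out` are hypothetical and not part of the notebook.

```python
def conv2d_out(n, kernel, stride=1, padding=0):
    # Hypothetical helper: output length along one spatial axis for an undilated convolution
    return (n + 2 * padding - kernel) // stride + 1


def maxpool2d_out(n, kernel):
    # Hypothetical helper: assumes stride == kernel size, floor mode, no padding
    return (n - kernel) // kernel + 1


side = 32
side = conv2d_out(side, 3)     # conv1: 32 -> 30
side = maxpool2d_out(side, 2)  # pool:  30 -> 15
side = conv2d_out(side, 3)     # conv2: 15 -> 13
side = maxpool2d_out(side, 2)  # pool:  13 -> 6
flat = 16 * side * side        # conv2 produces 16 channels
print(side, flat)  # 6 576
```

The same 576 is what `num_flat_features` computes at runtime for a `(1, 16, 6, 6)` tensor, since it multiplies every dimension except the batch dimension.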
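When a network forks into several output heads, as `Net2` does, training all heads together usually means adding the per-head losses into one scalar before calling `backward()`. If a loop instead reassigns `loss` for each key, only the last head's loss stays in the graph and the earlier heads never receive gradients. The minimal sketch that follows illustrates this with a hypothetical toy model (a `trunk` layer and a `heads` `nn.ModuleDict`, neither taken from the notebook) and assumes PyTorch 1.7 or later for `zero_grad(set_to_none=True)`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy model: a shared trunk forked into heads 'a' and 'b'
trunk = nn.Linear(8, 4)
heads = nn.ModuleDict({'a': nn.Linear(4, 5), 'b': nn.Linear(4, 5)})
criterion = nn.MSELoss()

x = torch.randn(1, 8)
# Targets keep the batch dimension so they match the (1, 5) head outputs
target = {'a': torch.randn(1, 5), 'b': torch.randn(1, 5)}


def forward(x):
    h = torch.relu(trunk(x))
    return {key: head(h) for key, head in heads.items()}


def grads_after(accumulate):
    # Clear old gradients so that an untouched head reports None
    trunk.zero_grad(set_to_none=True)
    heads.zero_grad(set_to_none=True)
    output = forward(x)
    loss = torch.zeros(())
    for key in output:
        per_head = criterion(output[key], target[key])
        # accumulate=True sums the losses; False keeps only the latest one
        loss = loss + per_head if accumulate else per_head
    loss.backward()
    return {key: heads[key].weight.grad for key in heads}


summed = grads_after(accumulate=True)
overwritten = grads_after(accumulate=False)
print(summed['a'] is None, overwritten['a'] is None)  # False True
```

Summing keeps a single `backward()` call while still routing gradients through every branch; weighting the terms differently is a separate design choice that this sketch does not cover.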