
@douglasrizzo
Created July 2, 2020 07:05
wandb_pyg.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "wandb_pyg.ipynb",
"provenance": [],
"authorship_tag": "ABX9TyPYhkkDHWwY+bJDaClh3zcG",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/douglasrizzo/1418883a8779563967bea80ef699d2a9/wandb_pyg.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "Q3WTkk3rRmJp",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"outputId": "7ce99522-b69d-45ec-9ff2-ff5bdc4a9353"
},
"source": [
"!pip install torch-scatter==latest+cu101 -f https://pytorch-geometric.com/whl/torch-1.5.0.html\n",
"!pip install torch-sparse==latest+cu101 -f https://pytorch-geometric.com/whl/torch-1.5.0.html\n",
"!pip install torch-cluster==latest+cu101 -f https://pytorch-geometric.com/whl/torch-1.5.0.html\n",
"!pip install torch-spline-conv==latest+cu101 -f https://pytorch-geometric.com/whl/torch-1.5.0.html\n",
"!pip install torch_geometric wandb"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Looking in links: https://pytorch-geometric.com/whl/torch-1.5.0.html\n",
"Collecting torch-scatter==latest+cu101\n",
" Using cached https://pytorch-geometric.com/whl/torch-1.5.0/torch_scatter-latest%2Bcu101-cp36-cp36m-linux_x86_64.whl\n",
"Installing collected packages: torch-scatter\n",
" Found existing installation: torch-scatter 2.0.5\n",
" Uninstalling torch-scatter-2.0.5:\n",
" Successfully uninstalled torch-scatter-2.0.5\n",
"Successfully installed torch-scatter-2.0.5\n",
"Looking in links: https://pytorch-geometric.com/whl/torch-1.5.0.html\n",
"Collecting torch-sparse==latest+cu101\n",
" Using cached https://pytorch-geometric.com/whl/torch-1.5.0/torch_sparse-latest%2Bcu101-cp36-cp36m-linux_x86_64.whl\n",
"Requirement already satisfied: scipy in /usr/local/lib/python3.6/dist-packages (from torch-sparse==latest+cu101) (1.4.1)\n",
"Requirement already satisfied: numpy>=1.13.3 in /usr/local/lib/python3.6/dist-packages (from scipy->torch-sparse==latest+cu101) (1.18.5)\n",
"Installing collected packages: torch-sparse\n",
" Found existing installation: torch-sparse 0.6.6\n",
" Uninstalling torch-sparse-0.6.6:\n",
" Successfully uninstalled torch-sparse-0.6.6\n",
"Successfully installed torch-sparse-0.6.6\n",
"Looking in links: https://pytorch-geometric.com/whl/torch-1.5.0.html\n",
"Collecting torch-cluster==latest+cu101\n",
" Using cached https://pytorch-geometric.com/whl/torch-1.5.0/torch_cluster-latest%2Bcu101-cp36-cp36m-linux_x86_64.whl\n",
"Installing collected packages: torch-cluster\n",
" Found existing installation: torch-cluster 1.5.5\n",
" Uninstalling torch-cluster-1.5.5:\n",
" Successfully uninstalled torch-cluster-1.5.5\n",
"Successfully installed torch-cluster-1.5.5\n",
"Looking in links: https://pytorch-geometric.com/whl/torch-1.5.0.html\n",
"Collecting torch-spline-conv==latest+cu101\n",
" Using cached https://pytorch-geometric.com/whl/torch-1.5.0/torch_spline_conv-latest%2Bcu101-cp36-cp36m-linux_x86_64.whl\n",
"Installing collected packages: torch-spline-conv\n",
" Found existing installation: torch-spline-conv 1.2.0\n",
" Uninstalling torch-spline-conv-1.2.0:\n",
" Successfully uninstalled torch-spline-conv-1.2.0\n",
"Successfully installed torch-spline-conv-1.2.0\n",
"Requirement already satisfied: torch_geometric in /usr/local/lib/python3.6/dist-packages (1.5.0)\n",
"Collecting wandb\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/00/8e/d43984196a0fa8ef961ae3dce91ada52ae7747fbf39d41f5743c27152d97/wandb-0.9.2-py2.py3-none-any.whl (1.4MB)\n",
"\u001b[K |████████████████████████████████| 1.4MB 2.9MB/s \n",
"\u001b[?25hRequirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from torch_geometric) (4.41.1)\n",
"Requirement already satisfied: scipy in /usr/local/lib/python3.6/dist-packages (from torch_geometric) (1.4.1)\n",
"Requirement already satisfied: plyfile in /usr/local/lib/python3.6/dist-packages (from torch_geometric) (0.7.2)\n",
"Requirement already satisfied: googledrivedownloader in /usr/local/lib/python3.6/dist-packages (from torch_geometric) (0.4)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from torch_geometric) (2.23.0)\n",
"Requirement already satisfied: torch in /usr/local/lib/python3.6/dist-packages (from torch_geometric) (1.5.1+cu101)\n",
"Requirement already satisfied: h5py in /usr/local/lib/python3.6/dist-packages (from torch_geometric) (2.10.0)\n",
"Requirement already satisfied: scikit-learn in /usr/local/lib/python3.6/dist-packages (from torch_geometric) (0.22.2.post1)\n",
"Requirement already satisfied: ase in /usr/local/lib/python3.6/dist-packages (from torch_geometric) (3.19.1)\n",
"Requirement already satisfied: rdflib in /usr/local/lib/python3.6/dist-packages (from torch_geometric) (5.0.0)\n",
"Requirement already satisfied: pandas in /usr/local/lib/python3.6/dist-packages (from torch_geometric) (1.0.5)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torch_geometric) (1.18.5)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.6/dist-packages (from torch_geometric) (2.4)\n",
"Requirement already satisfied: numba in /usr/local/lib/python3.6/dist-packages (from torch_geometric) (0.48.0)\n",
"Collecting shortuuid>=0.5.0\n",
" Downloading https://files.pythonhosted.org/packages/25/a6/2ecc1daa6a304e7f1b216f0896b26156b78e7c38e1211e9b798b4716c53d/shortuuid-1.0.1-py3-none-any.whl\n",
"Collecting GitPython>=1.0.0\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/8c/f9/c315aa88e51fabdc08e91b333cfefb255aff04a2ee96d632c32cb19180c9/GitPython-3.1.3-py3-none-any.whl (451kB)\n",
"\u001b[K |████████████████████████████████| 460kB 15.8MB/s \n",
"\u001b[?25hRequirement already satisfied: six>=1.10.0 in /usr/local/lib/python3.6/dist-packages (from wandb) (1.12.0)\n",
"Collecting watchdog>=0.8.3\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/0e/06/121302598a4fc01aca942d937f4a2c33430b7181137b35758913a8db10ad/watchdog-0.10.3.tar.gz (94kB)\n",
"\u001b[K |████████████████████████████████| 102kB 9.0MB/s \n",
"\u001b[?25hRequirement already satisfied: psutil>=5.0.0 in /usr/local/lib/python3.6/dist-packages (from wandb) (5.4.8)\n",
"Collecting configparser>=3.8.1\n",
" Downloading https://files.pythonhosted.org/packages/4b/6b/01baa293090240cf0562cc5eccb69c6f5006282127f2b846fad011305c79/configparser-5.0.0-py3-none-any.whl\n",
"Collecting subprocess32>=3.5.3\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/32/c8/564be4d12629b912ea431f1a50eb8b3b9d00f1a0b1ceff17f266be190007/subprocess32-3.5.4.tar.gz (97kB)\n",
"\u001b[K |████████████████████████████████| 102kB 9.1MB/s \n",
"\u001b[?25hRequirement already satisfied: Click>=7.0 in /usr/local/lib/python3.6/dist-packages (from wandb) (7.1.2)\n",
"Collecting gql==0.2.0\n",
" Downloading https://files.pythonhosted.org/packages/c4/6f/cf9a3056045518f06184e804bae89390eb706168349daa9dff8ac609962a/gql-0.2.0.tar.gz\n",
"Requirement already satisfied: nvidia-ml-py3>=7.352.0 in /usr/local/lib/python3.6/dist-packages (from wandb) (7.352.0)\n",
"Requirement already satisfied: python-dateutil>=2.6.1 in /usr/local/lib/python3.6/dist-packages (from wandb) (2.8.1)\n",
"Collecting docker-pycreds>=0.4.0\n",
" Downloading https://files.pythonhosted.org/packages/f5/e8/f6bd1eee09314e7e6dee49cbe2c5e22314ccdb38db16c9fc72d2fa80d054/docker_pycreds-0.4.0-py2.py3-none-any.whl\n",
"Requirement already satisfied: PyYAML>=3.10 in /usr/local/lib/python3.6/dist-packages (from wandb) (3.13)\n",
"Collecting sentry-sdk>=0.4.0\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/2f/6b/939519d77c95a9b2c85b771e9dccbf9e69cb90016c7cd63887c26400dd7a/sentry_sdk-0.15.1-py2.py3-none-any.whl (105kB)\n",
"\u001b[K |████████████████████████████████| 112kB 20.9MB/s \n",
"\u001b[?25hRequirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->torch_geometric) (3.0.4)\n",
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->torch_geometric) (1.24.3)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->torch_geometric) (2.9)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->torch_geometric) (2020.6.20)\n",
"Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from torch->torch_geometric) (0.16.0)\n",
"Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.6/dist-packages (from scikit-learn->torch_geometric) (0.15.1)\n",
"Requirement already satisfied: matplotlib in /usr/local/lib/python3.6/dist-packages (from ase->torch_geometric) (3.2.2)\n",
"Requirement already satisfied: pyparsing in /usr/local/lib/python3.6/dist-packages (from rdflib->torch_geometric) (2.4.7)\n",
"Requirement already satisfied: isodate in /usr/local/lib/python3.6/dist-packages (from rdflib->torch_geometric) (0.6.0)\n",
"Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.6/dist-packages (from pandas->torch_geometric) (2018.9)\n",
"Requirement already satisfied: decorator>=4.3.0 in /usr/local/lib/python3.6/dist-packages (from networkx->torch_geometric) (4.4.2)\n",
"Requirement already satisfied: llvmlite<0.32.0,>=0.31.0dev0 in /usr/local/lib/python3.6/dist-packages (from numba->torch_geometric) (0.31.0)\n",
"Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from numba->torch_geometric) (47.3.1)\n",
"Collecting gitdb<5,>=4.0.1\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/48/11/d1800bca0a3bae820b84b7d813ad1eff15a48a64caea9c823fc8c1b119e8/gitdb-4.0.5-py3-none-any.whl (63kB)\n",
"\u001b[K |████████████████████████████████| 71kB 7.0MB/s \n",
"\u001b[?25hCollecting pathtools>=0.1.1\n",
" Downloading https://files.pythonhosted.org/packages/e7/7f/470d6fcdf23f9f3518f6b0b76be9df16dcc8630ad409947f8be2eb0ed13a/pathtools-0.1.2.tar.gz\n",
"Collecting graphql-core<2,>=0.5.0\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/b0/89/00ad5e07524d8c523b14d70c685e0299a8b0de6d0727e368c41b89b7ed0b/graphql-core-1.1.tar.gz (70kB)\n",
"\u001b[K |████████████████████████████████| 71kB 7.0MB/s \n",
"\u001b[?25hRequirement already satisfied: promise<3,>=2.0 in /usr/local/lib/python3.6/dist-packages (from gql==0.2.0->wandb) (2.3)\n",
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib->ase->torch_geometric) (0.10.0)\n",
"Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->ase->torch_geometric) (1.2.0)\n",
"Collecting smmap<4,>=3.0.1\n",
" Downloading https://files.pythonhosted.org/packages/b0/9a/4d409a6234eb940e6a78dfdfc66156e7522262f5f2fecca07dc55915952d/smmap-3.0.4-py2.py3-none-any.whl\n",
"Building wheels for collected packages: watchdog, subprocess32, gql, pathtools, graphql-core\n",
" Building wheel for watchdog (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for watchdog: filename=watchdog-0.10.3-cp36-none-any.whl size=73870 sha256=aed2a78037d7a144f89b77552527b21c91943daf3a807bba6ee7483c109ca6a0\n",
" Stored in directory: /root/.cache/pip/wheels/a8/1d/38/2c19bb311f67cc7b4d07a2ec5ea36ab1a0a0ea50db994a5bc7\n",
" Building wheel for subprocess32 (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for subprocess32: filename=subprocess32-3.5.4-cp36-none-any.whl size=6489 sha256=b30ea432ac9852ecafd83e11d7cc3432eece93e4b279b4d694b8b226031560fd\n",
" Stored in directory: /root/.cache/pip/wheels/68/39/1a/5e402bdfdf004af1786c8b853fd92f8c4a04f22aad179654d1\n",
" Building wheel for gql (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for gql: filename=gql-0.2.0-cp36-none-any.whl size=7630 sha256=530392ea89ce48265f79feb0003f499a3fe3807be05febe811edca9ac6a017eb\n",
" Stored in directory: /root/.cache/pip/wheels/ce/0e/7b/58a8a5268655b3ad74feef5aa97946f0addafb3cbb6bd2da23\n",
" Building wheel for pathtools (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for pathtools: filename=pathtools-0.1.2-cp36-none-any.whl size=8784 sha256=69b22ceed08304a10654b1ba4dcf8ee2246cee34b84b13adecdb40ce24655b4a\n",
" Stored in directory: /root/.cache/pip/wheels/0b/04/79/c3b0c3a0266a3cb4376da31e5bfe8bba0c489246968a68e843\n",
" Building wheel for graphql-core (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for graphql-core: filename=graphql_core-1.1-cp36-none-any.whl size=104650 sha256=ca58e3af1cedf4bf507a5bf047a919eae7426a3e965794959c4586ccec292efb\n",
" Stored in directory: /root/.cache/pip/wheels/45/99/d7/c424029bb0fe910c63b68dbf2aa20d3283d023042521bcd7d5\n",
"Successfully built watchdog subprocess32 gql pathtools graphql-core\n",
"Installing collected packages: shortuuid, smmap, gitdb, GitPython, pathtools, watchdog, configparser, subprocess32, graphql-core, gql, docker-pycreds, sentry-sdk, wandb\n",
"Successfully installed GitPython-3.1.3 configparser-5.0.0 docker-pycreds-0.4.0 gitdb-4.0.5 gql-0.2.0 graphql-core-1.1 pathtools-0.1.2 sentry-sdk-0.15.1 shortuuid-1.0.1 smmap-3.0.4 subprocess32-3.5.4 wandb-0.9.2 watchdog-0.10.3\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "ASH4q8m5RYMl",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 307
},
"outputId": "b876b3b4-f29d-44d0-da9f-ea2fc9b1a374"
},
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"from torch_geometric.datasets import Planetoid\n",
"from torch_geometric.nn import GCNConv\n",
"\n",
"import wandb\n",
"\n",
"\n",
"class Net(torch.nn.Module):\n",
"\n",
"    def __init__(self, in_feats, out_classes):\n",
"        super(Net, self).__init__()\n",
"        self.conv1 = GCNConv(in_feats, 16)\n",
"        self.conv2 = GCNConv(16, out_classes)\n",
"\n",
"    def forward(self, data):\n",
"        x, edge_index = data.x, data.edge_index\n",
"\n",
"        x = self.conv1(x, edge_index)\n",
"        x = F.relu(x)\n",
"        x = F.dropout(x, training=self.training)\n",
"        x = self.conv2(x, edge_index)\n",
"\n",
"        return F.log_softmax(x, dim=1)\n",
"\n",
"\n",
"if __name__ == \"__main__\":\n",
"    wandb.init(project='my_test', group='graph_nets')\n",
"\n",
"    dataset = Planetoid(root='/tmp/Cora', name='Cora')\n",
"    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"\n",
"    net = Net(dataset.num_node_features, dataset.num_classes).to(device)\n",
"    wandb.watch(net)\n",
"\n",
"    optimizer = torch.optim.Adam(net.parameters(), lr=0.01, weight_decay=5e-4)\n",
"    data = dataset[0].to(device)\n",
"\n",
"    net.train()\n",
"    for _ in range(100):\n",
"        optimizer.zero_grad()\n",
"        out = net(data)\n",
"        loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"    net.eval()\n",
"    _, pred = net(data).max(dim=1)\n",
"    correct = float(pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())\n",
"    acc = correct / data.test_mask.sum().item()\n",
"    print('Accuracy: {:.4f}'.format(acc))\n"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"application/javascript": [
"\n",
" window._wandbApiKey = new Promise((resolve, reject) => {\n",
" function loadScript(url) {\n",
" return new Promise(function(resolve, reject) {\n",
" let newScript = document.createElement(\"script\");\n",
" newScript.onerror = reject;\n",
" newScript.onload = resolve;\n",
" document.body.appendChild(newScript);\n",
" newScript.src = url;\n",
" });\n",
" }\n",
" loadScript(\"https://cdn.jsdelivr.net/npm/postmate/build/postmate.min.js\").then(() => {\n",
" const iframe = document.createElement('iframe')\n",
" iframe.style.cssText = \"width:0;height:0;border:none\"\n",
" document.body.appendChild(iframe)\n",
" const handshake = new Postmate({\n",
" container: iframe,\n",
" url: 'https://app.wandb.ai/authorize'\n",
" });\n",
" const timeout = setTimeout(() => reject(\"Couldn't auto authenticate\"), 5000)\n",
" handshake.then(function(child) {\n",
" child.on('authorize', data => {\n",
" clearTimeout(timeout)\n",
" resolve(data)\n",
" });\n",
" });\n",
" })\n",
" });\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc\n"
],
"name": "stderr"
},
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
" Logging results to <a href=\"https://wandb.com\" target=\"_blank\">Weights & Biases</a> <a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">(Documentation)</a>.<br/>\n",
" Project page: <a href=\"https://app.wandb.ai/tetamusha/my_test\" target=\"_blank\">https://app.wandb.ai/tetamusha/my_test</a><br/>\n",
" Run page: <a href=\"https://app.wandb.ai/tetamusha/my_test/runs/3ehw3yjc\" target=\"_blank\">https://app.wandb.ai/tetamusha/my_test/runs/3ehw3yjc</a><br/>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x\n",
"Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx\n",
"Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx\n",
"Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y\n",
"Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty\n",
"Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally\n",
"Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph\n",
"Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index\n",
"Processing...\n",
"Done!\n",
"Accuracy: 0.8040\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "24jclWBrTnrw",
"colab_type": "code",
"colab": {}
},
"source": [
""
],
"execution_count": null,
"outputs": []
}
]
}
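
The training cell in the notebook above uses GCNConv for both layers. As a rough illustration of what one such layer computes with its default settings, here is a dense, dependency-free sketch of the graph-convolution rule H' = D^-1/2 (A + I) D^-1/2 H W: self-loops are added to the adjacency matrix, each entry is scaled by the inverse square roots of its two endpoint degrees, and the result multiplies the linearly transformed node features. The helper names gcn_norm_dense and propagate are hypothetical. The sketch omits GCNConv's bias term and the ReLU/dropout applied between layers, and PyG itself computes the same quantity sparsely through message passing rather than with dense matrices.

```python
import math


def gcn_norm_dense(edge_index, num_nodes):
    """Return D^-1/2 (A + I) D^-1/2 as a dense nested list (hypothetical helper).

    edge_index is a PyG-style pair (sources, targets); for an undirected
    graph such as Cora, both directions of every edge are assumed present.
    """
    # Adjacency with self-loops: A_hat = A + I (row = receiving node)
    a_hat = [[0.0] * num_nodes for _ in range(num_nodes)]
    for src, dst in zip(*edge_index):
        a_hat[dst][src] = 1.0
    for i in range(num_nodes):
        a_hat[i][i] = 1.0
    # Degrees in A_hat; never zero thanks to the self-loops
    deg = [sum(row) for row in a_hat]
    inv_sqrt = [1.0 / math.sqrt(d) for d in deg]
    # Symmetric normalization: entry (i, j) scaled by 1 / sqrt(d_i * d_j)
    return [[inv_sqrt[i] * a_hat[i][j] * inv_sqrt[j] for j in range(num_nodes)]
            for i in range(num_nodes)]


def propagate(norm, x, weight):
    """One bias-free GCN layer: norm @ (x @ weight), using plain lists."""
    # Linear transform of node features: x @ weight
    xw = [[sum(xi[k] * weight[k][c] for k in range(len(weight)))
           for c in range(len(weight[0]))] for xi in x]
    # Neighborhood aggregation with the normalized adjacency
    return [[sum(norm[i][j] * xw[j][c] for j in range(len(xw)))
             for c in range(len(xw[0]))] for i in range(len(norm))]


# Toy path graph 0-1-2, each undirected edge listed in both directions
norm = gcn_norm_dense(([0, 1, 1, 2], [1, 0, 2, 1]), 3)
out = propagate(norm, [[1.0], [1.0], [1.0]], [[1.0]])
print(norm)
print(out)
```

On this path graph the degrees including self-loops are 2, 3 and 2, so the normalized entry for node 0 with itself is 1/2 and for the pair (0, 1) it is 1/sqrt(6); the end nodes 0 and 2 have no direct connection, so their entry stays 0.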
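
The evaluation lines at the end of the notebook's training cell take the arg-max of the log-softmax scores (`net(data).max(dim=1)`) and count correct predictions only on the nodes flagged by `data.test_mask`. The same arithmetic can be sketched in plain Python. The helper `masked_accuracy` and the toy inputs are hypothetical and stand in for the tensor operations.

```python
def masked_accuracy(log_probs, labels, mask):
    """Fraction of correct predictions over nodes where mask is True.

    log_probs: per-node lists of class scores (e.g. log-softmax outputs)
    labels:    true class index per node
    mask:      one boolean per node, playing the role of data.test_mask
    """
    # Predicted class = index of the largest score, like .max(dim=1)[1]
    preds = [max(range(len(row)), key=row.__getitem__) for row in log_probs]
    # Keep only the (prediction, label) pairs selected by the mask
    selected = [(p, y) for p, y, m in zip(preds, labels, mask) if m]
    if not selected:
        raise ValueError("mask selects no nodes")
    correct = sum(1 for p, y in selected if p == y)
    return correct / len(selected)


example_scores = [[-0.1, -2.3], [-1.6, -0.2], [-0.5, -0.9], [-3.0, -0.05]]
example_labels = [0, 1, 1, 1]
example_mask = [True, True, True, False]
acc = masked_accuracy(example_scores, example_labels, example_mask)
print(acc)  # prints 0.6666666666666666
```

The three masked nodes yield two correct predictions (node 2 is predicted as class 0 but labeled 1), and the fourth node is excluded by the mask, so the result is 2/3.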