Skip to content

Instantly share code, notes, and snippets.

@k1ochiai
Created May 19, 2019 12:59
Show Gist options
  • Save k1ochiai/cd0279ca79dd74e91a2b5e1187928adb to your computer and use it in GitHub Desktop.
Save k1ochiai/cd0279ca79dd74e91a2b5e1187928adb to your computer and use it in GitHub Desktop.
DGL sample
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# DGL sample: a one-layer GCN on a toy graph\n",
"\n",
"Adapted from the DGL basics tutorial: https://docs.dgl.ai/tutorials/basics/1_first.html\n",
"\n",
"The notebook builds a 5-node undirected graph, draws it with networkx,\n",
"and trains a one-layer graph convolutional network (GCN) to classify the nodes for 3 epochs."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Render matplotlib figures (such as the nx.draw graph plot) inline in the notebook.\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import dgl\n",
"\n",
"def build_sample_graph():\n",
"    \"\"\"Build the 5-node toy graph used throughout this notebook.\n",
"\n",
"    The six undirected edges are inserted twice -- once as listed and once\n",
"    reversed -- so the returned DGLGraph holds 12 directed edges.\n",
"    \"\"\"\n",
"    pairs = [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (2, 3)]\n",
"    sources = tuple(u for u, _ in pairs)\n",
"    targets = tuple(v for _, v in pairs)\n",
"\n",
"    graph = dgl.DGLGraph()\n",
"    graph.add_nodes(5)\n",
"    # Forward edges are added first, then the same edges reversed.\n",
"    for heads, tails in ((sources, targets), (targets, sources)):\n",
"        graph.add_edges(heads, tails)\n",
"    return graph"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"We have 5 nodes.\n",
"We have 12 edges.\n"
]
}
],
"source": [
"G = build_sample_graph()\n",
"# 6 undirected pairs, each stored in both directions -> 12 directed edges.\n",
"print('We have %d nodes.' % G.number_of_nodes())\n",
"print('We have %d edges.' % G.number_of_edges())"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAcUAAAE1CAYAAACWU/udAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4xLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvAOZPmwAAIABJREFUeJzt3X1czff/P/BHHSlaUSrpykVNQnWi\nWMJMrmYzn82GUyEylzMfbhtj320+M9cXU4yxuYr4zMyYi20sLNe5KJWahKhEmK4Wpzrn/P74pN/a\nSaTOeZ2Lx/12c7utc07nPPp88PB8vd/v19tEpVKpQERERDAVHYCIiEhXsBSJiIgqsBSJiIgqsBSJ\niIgqsBSJiIgqsBSJiIgqsBSJiIgqsBSJiIgqsBSJiIgqsBSJiIgqsBSJiIgqsBSJiIgqsBSJiIgq\nsBSJiIgqsBSJiIgqsBSJiIgqsBSJiIgqNBAdgIhIH8jlcmRlZaGoqAhlZWUwMzODlZUVXF1dYW5u\nLjoe1RMTlUqlEh2CiEhX5efnIyMjA3l5eQAApVJZ+Zyp6f8W2xwcHODh4YGmTZsKyUj1h6VIRPQE\nmZmZSEtLg0KheOprJRIJvLy80KpVK80HI43h8ikRUTUyMzORmppaZTKsiUKhQFpaGgCwGPUYS5GI\n6B/y8/ORlpZWpRD37duH2NhYZGZmomfPnpg2bZra9z0uxqZNm3IpVU/x7FMion/IyMhQWzK1tbXF\n0KFD0bdv3xq/V6FQICMjQ5PxSIM4KRIR/Y1cLq88qebvunXrBuB/hSmXy2t8j7y8PMjlcp6Vqoc4\nKRIR/U1WVla9vE92dna9vA9pF0uRiOhvioqKnvnkmidRKpUoLCysp0SkTSxFIqK/KSsr06n3Ie3i\nMUUiMnrFxcVITk5GYmIiTExM4OLiUuf3NDMzq4dkpG0sRSIyKrdv30ZiYmLlr4SEBGRlZaFDhw6Q\nSqXo2bNntd+nUCigUCigVCqhVCpRWloKiUQCiUSi9lpTU1NYW1tr+kchDeCONkRkkBQKBa5cuVKl\nABMTE1FeXg6pVFr5y8/PD56enmjQ4H8zglwuR2xsrNpxxW3btmH79u1VHpPJZAgJCVH7bFNTUwQH\nB/PsUz3EUiQivVdSUlK5/Pn4V3JyMhwdHasUoFQqhbOzM0xMTGp8v3PnzuH27dvPncfR0RH+/v7P\n/f0kDpdPiUiv3LlzR236u3HjBry8vCqLLyQkBD4+PmjSpMlzfYaHhwfu3r37THue/pNEIoGHh8dz\nfS6Jx0mRiHSSUqlERkZGlWN/iYmJkMvl8PPzqzL9tWvXrt5PbKnNZuCPcVNw/cdSJCLhSkpKkJKS\norb8aW9vr7b86erq+tTlz/qSmZmJlJQUKJXKyttEPQkL0TCwFIlIq+7evat29mdmZiY8PT2rTIA+\nPj7CN9VWqVQYMmQIRo4ciYYNGwLg/RQNHUuRiDRCqVTi6tWrasf/SkpK1KY/Ly+vytLRJUePHsW7\n776LP/74A+Xl5cjOzkZhYSHKyspgZmYGa2truLi48CxTA8JSJKI6e/jwIS5dulSl/JKSkmBra6t2\n/M/NzU1ry5911a9fPwwbNgwRERGio5CWsBSJqFbu3bunNv1du3YNbdu2rVJ+vr6+sLGxER33uZ09\nexZDhgxBRkaGTk6xpBksRSKqllKpxPXr19UKsKioCL6+vlUmQC8vL4NbQnzzzTfxyiuv4P333xcd\nhbSIpUhEePTokdry58WLF2FjY6N2/K9Vq1Z6s/z5vC5duoTg4GBcu3YNjRs3Fh2HtIilSGRk7t+/\nj4sXL1YpwIyMDHh4eFRue/Z4+dPW1lZ0XCHCwsLQoUMHzJo1S3QU0jKWIpGBUqlU1S5/FhQUwNfX\nt8r01759e1hYWIiOrBOuXbuGLl264OrVq8+9
Iw7pL5YikQGQy+VITU1VW/60srKqsvH14+XPp12I\nbszGjx8Pe3t7fPHFF6KjkAAsRSI98+eff6otf165cgXu7u5qZ3/a2dmJjqtXcnJy4O3tjfT0dP5v\nZ6RYikQ6SqVS4caNG2rLn3/++Sd8fHyqTIAdOnTg8mc9mD59OlQqFb788kvRUUgQliKRDigtLa12\n+dPS0lLt7M82bdpw+VMD7t27h7Zt2yI5ORnOzs6i45AgvHUUkZbl5+erLX9evnwZrVu3riy+1157\nDVKpFPb29qLjGo3IyEi8/fbbLEQjx0mRSENUKhWysrIqb3n0+Ne9e/eqLH9KpVJ07NgRjRo1Eh3Z\naBUWFqJNmzY4c+YM3N3dRcchgViKRPWgrKwMaWlpasf/LCws1JY/PTw8uPypYxYuXIjk5GTExMSI\njkKCsRSJaqmgoEBt+fOPP/5Ay5Ytq2x95uvri+bNm4uOS09RUlKCNm3a4LfffkPHjh1FxyHBWIpE\nT6BSqZCdna02/d25cwfe3t5qy5+WlpaiI9NzWLlyJWJjY7F7927RUUgHsBSJ8L/lz8uXL6sd/2vQ\noAH8/PyqTIAeHh6QSCSiI1M9KC0thYeHB3bu3IkuXbqIjkM6gGefktEpLCxEUlJSlfJLTU2Fm5tb\nZfF9+OGHkEqlcHR0FB2XNGjr1q3w9PRkIVIlTopksFQqFW7duoXExMQqE2Bubi46duxYZeszb29v\nLn8aGYVCAS8vL6xduxavvPKK6DikI1iKZBDKy8tx+fJlteN/JiYmand+f/HFF9GgARdJjN2OHTuw\nYsUKnDhxwuBvhUXPjqVIeqe4uBhJSUlVpr9Lly7BxcVFbfNrR0dH/oVHalQqFfz8/PDFF1/g9ddf\nFx2HdAj/uUw6S6VSITc3V236y8nJQYcOHSrLb/To0fD29oaVlZXoyKQnDhw4AJVKhddee010FNIx\nnBQNmFwuR1ZWFoqKilBWVgYzMzNYWVnB1dUV5ubmouNVoVAokJ6ernb8T6lUVln+9PPzQ9u2bbn8\nSc9NpVIhKCgIU6dOxbBhw0THIR3DUjRA+fn5yMjIQF5eHgBAqVRWPvd4JxUHBwd4eHigadOmWs9X\nXFyM5OTkKtNfSkoKnJyc1HZ/cXJy4vIn1aujR49i3LhxSEtL46U1pIalaGAyMzORlpYGhULx1NdK\nJBJ4eXmhVatWGstz+/ZttekvKysL7du3r1J+Pj4+sLa21lgOosf69u0LmUyGMWPGiI5COoilaEAy\nMzORmppaZTJ8mvoqRoVCgStXrqgd/ysrK1M7+9PT0xNmZmZ1+jyi5xEfH4+3334bGRkZaNiwoeg4\npINYigYiPz8fp06dqjIhlpWVYc2aNUhMTERxcTFatGiBESNGwN/fv8r3SiQSBAYGPvNSaklJSeXy\n5+MJMCUlBc2bN1db/nRxceHyJ+mMf/3rXwgODsaUKVNERyEdxVI0EOfOncPt27erPPbo0SPs2rUL\nwcHBsLe3x7lz57B06VKsXLlSbaNqR0dHtbIEgDt37qhNfzdu3ICXl5fa8meTJk00+jMS1UVKSgr6\n9OmD69ev8zZd9EQsRQMgl8sRGxv7TMumU6ZMwfDhwxEUFFTlcVNTU7Rq1QopKSlVjv/J5XK16c/L\ny4vLn6R3QkND4e3tjY8++kh0FNJhPK/dAGRlZT3T6x48eICcnBy4ubmpPffw4UMsX74cf/75J6RS\nKSZOnAipVApXV1cuf5Leu3r1Kn799VesWbNGdBTScSxFA1BUVPTUKbG8vBzLli1D79694erqqva8\nubk5Jk+eDD8/P03FJBJm0aJFmDhxIs9wpqdiKRqAsrKyGp9XKpVYvnw5GjRogAkTJjz3+xDpo5yc\nHOzcuRPp6emio5AeMBUdgOqupuN7KpUKUVFRyM/Px6xZs2rcCYbHCckQLV26FOHh4bCzsxMdhfQA\nJ0UDYGVl
BVNT02qXUFevXo3s7GzMnTu3xq3dTE1NubREBufu3bvYvHkzkpOTRUchPcFSNACurq7V\nLg3l5eXhl19+gZmZGUaOHFn5+OTJk9GrVy+117u4uGgyJpHWRUVF4Z133oGzs7PoKKQneEmGgaju\nOsXaeNJ1ikT6qqCgAO7u7oiPj0ebNm1ExyE9wWOKBsLDw+O5NzcuLS3lxcxkcNasWYMBAwawEKlW\nWIoGomnTpvDy8qr1NYWmpqbIz89Hr169sHXrVg2lI9KukpISrFixghfqU62xFA2Is7Mz9u7d+8wb\ngkskErRv3x5jxozBoUOHMG/ePIwYMQKFhYUaTkqkWevXr0dgYCA6duwoOgrpGZaiAVm0aBGuX7+O\nHj16wNHREaamppX3T3zs8WOOjo4IDAysvDuGVCrF+fPn0bhxY/j5+eH06dMCfgKiuistLcWSJUsw\ne/Zs0VFID/FEGwNx8eJF9OnTBxcuXKjcsUYulyM7OxuFhYUoKyuDmZkZrK2t4eLiUuPlGbt27cLE\niRMxdepUzJw5kzdiJb2yYcMG/Pe//8XBgwdFRyE9xFI0AKWlpejSpQumTp2K0aNH18t7ZmdnIyws\nDCYmJtiyZQsv1yC9oFAo4OXlhXXr1lV72RHR03D51ADMmzcPzs7OCA8Pr7f3dHFxQWxsLPr27YvO\nnTtj165d9fbeRJqyc+dO2Nvb4+WXXxYdhfQUJ0U9d+HCBQwYMACJiYlwcnLSyGecOXMGISEh6NOn\nD7788ks0btxYI59DVBcqlQpSqRTz58/Ha6+9JjoO6SlOinpMLpdj1KhRWL58ucYKEQC6du2KhIQE\nlJSUoHPnzkhMTNTYZxE9r/3798PExAQDBw4UHYX0GEtRj33++edwd3dHaGioxj/L2toaW7Zswccf\nf4y+fftixYoV4CID6QqVSoV58+Zh9uzZvP8n1QmXT/XU2bNn8frrr+PixYtwdHTU6mdfvXoVoaGh\nsLW1xcaNG9G8eXOtfj7RPx05cgQTJkxAamoqz5amOuGkqIcePXqEUaNGITIyUuuFCADu7u44duwY\nOnXqBD8/P/zyyy9az0D0d/PmzcNHH33EQqQ646Soh2bMmIFr167h+++/F75UdPToUYwcORJvv/02\nFixYUOP1j0SacObMGQwdOhQZGRm8JyjVGSdFPXPy5Els2bIFa9asEV6IANCrVy8kJiYiMzMTXbt2\nRVpamuhIZGQWLFiADz/8kIVI9YKlqEdKSkoQHh6OVatWwd7eXnScSra2tvjhhx8wadIk9OzZE998\n8w1PwiGtSE5OxpkzZxARESE6ChkILp/qkenTpyM3Nxfbt28XHeWJ0tLSIJPJ4O7ujm+++Qa2trai\nI5EBCw0NhY+PD2bOnCk6ChkITop64tixY/jvf/+LVatWiY5SIy8vL5w+fRpubm6QSqX4/fffRUci\nA5WRkYGDBw9i4sSJoqOQAeGkqAf++usv+Pr6YtmyZRg8eLDoOM/s559/RkREBMaMGYPPPvuMx3yo\nXo0bNw6Ojo74/PPPRUchA8JS1ANTpkxBfn4+tmzZIjpKrd25cwfh4eHIz89HTEwM74JO9SI7Oxs+\nPj5IT0+HnZ2d6DhkQLh8quOOHDmCH3/8EVFRUaKjPJfmzZtj//79GDZsGLp27YqYmBjRkcgALFu2\nDKNHj2YhUr3jpKjDioqK4OPjg1WrVhnEBscJCQkICQlBQEAAVq1aBWtra9GRSA/dvXsXnp6eSElJ\n0eiev2ScOCnqsBkzZuCVV14xiEIEAD8/P5w7dw6NGjWCn58fzpw5IzoS6aHIyEgMHTqUhUgawUlR\nRx06dAgRERFISkpC06ZNRcepd4+va/z3v/+NGTNmcHsueiYFBQVwd3dHfHw8j0+TRrAUdVBhYSG8\nvb2xbt069O/fX3QcjcnKysKIESNgYmKCLVu2wMXFRXQk0nELFixAamqqXp
50RvqBpaiDxo4dC1NT\nU6xbt050FI1TKBRYuHAhoqKi8PXXX+PNN98UHYl0VElJCdq0aYPY2Fh06NBBdBwyUCxFHfPzzz9j\n4sSJSEpKMqoTUU6fPo2QkBD069cPy5cvR+PGjUVHIh0TFRWFo0ePYteuXaKjkAHjiTY65MGDBxg3\nbhzWr19vVIUIAC+99BISExNRXFwMf39/XLx4UXQk0iGlpaVYsmQJZs+eLToKGTiWog6ZNm0a3njj\nDQQHB4uOIoS1tTW2bt2K2bNno0+fPoiMjOTG4gQA2LJlC9q3bw9/f3/RUcjAcflUR+zduxdTp05F\nUlISXnjhBdFxhLt69SpCQkLQrFkzbNq0CQ4ODqIjkSAKhQLt2rXDt99+i5dffll0HDJwnBR1wJ9/\n/okJEyZg48aNLMQK7u7uOH78OPz8/CCVSvHLL7+IjkSCfP/992jevDl69uwpOgoZAU6KOiAsLAzN\nmjVDZGSk6Cg66ejRoxgxYgTeeecdLFiwAObm5qIjkZaoVCr4+vpi4cKFGDhwoOg4ZAQ4KQr2448/\n4syZM5g/f77oKDqrV69eSExMxPXr1/HSSy/hjz/+EB2JtGTfvn2QSCR49dVXRUchI8FSFOju3buY\nNGkSNm3aBEtLS9FxdFqzZs2wa9cuTJw4ET169MA333zDk3AMnEqlwrx58zB79myYmJiIjkNGgsun\nAg0bNgyurq5YunSp6Ch6JTU1FTKZDC+++CLWrVsHW1tb0ZFIAw4fPoxJkybh0qVL3AaQtIaToiA7\nduzAxYsXMXfuXNFR9E779u1x5swZuLq6QiqV4vfffxcdiTRg/vz5+Oijj1iIpFWcFAW4c+cOfH19\nsXv3brz00kui4+i1AwcOICIiAmPHjsWnn34KMzMz0ZGoHpw5cwZDhw5FRkYG/z8lrWIpaplKpcKQ\nIUPQtm1bLFy4UHQcg3D79m2Eh4ejoKAAMTExvHuCARg8eDD69euHyZMni45CRobLp1q2fft2XL58\nGXPmzBEdxWA4OjriwIEDGDp0KLp27Ypt27aJjkR1kJycjPj4eIwZM0Z0FDJCnBS1KDc3F76+vti/\nfz8CAgJExzFICQkJkMlk6NKlC1atWmV0e8gagpCQEEilUsyYMUN0FDJCnBS1RKVSYfz48Rg3bhwL\nUYP8/Pxw/vx5WFhYoFOnToiPjxcdiWohIyMDhw4dwoQJE0RHISPFUtSSLVu24MaNG/j0009FRzF4\nlpaWWLduHRYtWoRBgwZhwYIFUCgUomPRM1i0aBEmTZrECZ+E4fKpFuTk5MDPzw+//vor/Pz8RMcx\nKllZWQgLC4NEIkF0dDRcXFxER6InyMrKgq+vL65cuYJmzZqJjkNGipOihqlUKrz77ruYPHkyC1EA\nV1dXHD58GMHBwejcuTN2794tOhI9wbJlyzBmzBgWIgnFSVHDNmzYgJUrVyI+Pp7XWwl26tQphIaG\nol+/fli+fDkaN24sOhJVyMvLQ7t27ZCSkgInJyfRcciIcVLUoJs3b2LmzJnYvHkzC1EHBAYGIiEh\nAUVFRfD398fFixdFR6IKkZGRGDZsGAuRhOOkqCEqlQr9+vVDr1698PHHH4uOQ/+wZcsWTJ8+Hf/3\nf/+H999/nxtOC1RQUAB3d3ecPXsWrVu3Fh2HjBxLUUPWrl2Lb7/9FqdOnUKDBg1Ex6FqXL16FSEh\nIWjWrBk2bdoEBwcH0ZGM0vz58/HHH38gOjpadBQilqImXL9+HQEBAYiLi0P79u1Fx6EalJWV4bPP\nPsOmTZuwceNG9O/fX3Qko1JSUoLWrVvjyJEj/LNCOoGlWM+USiX69OmDAQMGcEcOPXLkyBGMHDkS\nQ4cOxfz582Fubi46klGIjIxEXFwcfvjhB9FRiACwFOvdV199ha1bt+L48eO85Y2euX//PsaOHYvM\nzExs374d7dq1Ex3JoJWWlsLd3R27d+
9G586dRcchAsCzT+vV1atXK5fiWIj6p1mzZti1axcmTJiA\nHj164NtvvwX/zag50dHR6NChAwuRdAonxXqiVCrRq1cv/Otf/8L06dNFx6E6Sk1NhUwmw4svvoh1\n69bB1tZWdCSDUl5ejnbt2mHDhg3o2bOn6DhElTgp1pOoqCgolUpMnTpVdBSqB+3bt8eZM2fg4uIC\nqVSKuLg40ZEMyvfffw9HR0f06NFDdBSiKjgp1oP09HR069YNp0+fhoeHh+g4VM8OHDiAiIgIjB07\nFp9++ik3YqgjpVIJX19fLF68GK+++qroOERVcFKsI4VCgfDwcHz22WcsRAM1cOBAJCQkID4+Hj17\n9sT169dFR9Jr+/btg5mZGQYMGCA6CpEalmIdffnll2jYsCEmT54sOgppkKOjI37++We888476NKl\nC7Zt2yY6kl5SqVSYN28eZs+ezV2ESCdx+bQO0tLS0KNHD8THx6NNmzai45CWJCQkQCaToUuXLvjq\nq69gZWUlOpLeiI2NxeTJk3Hp0iWeoU06iZPicyovL8eoUaMwd+5cFqKR8fPzw/nz52Fubg4/Pz/E\nx8eLjqQ35s+fj1mzZrEQSWexFJ/TkiVLYG1tjfHjx4uOQgJYWlrim2++wcKFC/H6669jwYIFUCgU\nomPptNOnT1fuN0ukq7h8+hySk5PRu3dvnDt3Di1bthQdhwTLyspCWFgYJBIJtmzZAmdnZ9GRdNIb\nb7yBAQMGYNKkSaKjED0RJ8VaKisrQ3h4OBYsWMBCJACAq6srDh8+jN69e6NTp07YvXu36Eg6Jykp\nCWfPnsXo0aNFRyGqESfFWpo7dy5OnDiBn3/+mWfPkZpTp04hNDQU/fv3x7Jly9C4cWPRkXSCTCZD\np06d8OGHH4qOQlQjlmItJCYmom/fvkhISICLi4voOKSjCgoKMHHiRFy8eBHbt2+Hj4+P6EhCXbly\nBd26dcO1a9d4pi7pPC6fPqPS0lKEh4djyZIlLESqUZMmTRATE4OZM2ciODgYUVFRRr2x+KJFizB5\n8mQWIukFTorP6NNPP8WFCxewd+9eLpvSM8vIyEBISAjs7e2xceNGODg4iI6kVVlZWfD19cWVK1fQ\nrFkz0XGInoqT4jM4f/48vv76a6xbt46FSLXi4eGBEydOwMfHB1KpFL/++qvoSFq1dOlSREREsBBJ\nb3BSfAq5XI7OnTtj1qxZCA0NFR2H9Njhw4cxatQoDB06FPPnz4e5ubnoSBqVl5eHdu3a4dKlS2jR\nooXoOETPhJPiU/znP//Biy++yAuOqc569+6NxMREXL16FYGBgbh8+bLoSBq1YsUKDB8+nIVIeoWT\nYg3i4+MxaNAgJCUloXnz5qLjkIFQqVRYu3Yt/u///g8LFy5ERESEwS3L5+fnw93dHefOnUPr1q1F\nxyF6ZizFJ3j06BH8/PwwZ84cDBs2THQcMkCpqakYPnw4PD09sW7dOtjY2IiOVG/mzZuHy5cvIzo6\nWnQUolrh8ukTfPLJJ+jYsSOGDh0qOgoZqPbt2yM+Ph5OTk6QSqU4duyY6Ej14q+//kJUVBRmzZol\nOgpRrXFSrMbJkycxZMgQJCUlwd7eXnQcMgL79+/H2LFj8e677+LTTz9FgwYNREd6bitWrMCxY8fw\nww8/iI5CVGssxX8oKSmBVCrFggULMGTIENFxyIjk5uYiPDwcRUVFiImJ0ctjcXK5HO7u7tizZw86\nd+4sOg5RrXH59B8+/vhj+Pv7sxBJ61q0aIGff/4Zb7/9Nrp27Yrt27eLjlRr0dHR6NixIwuR9BYn\nxb+Ji4vD8OHDkZyczIuNSagLFy5AJpPhpZdewqpVq/Rii7Ty8nK0a9cOGzduRI8ePUTHIXounBQr\n/PXXXxg9ejTWrFnDQiThOnXqhAsXLsDMzAx+fn6Ij48XHempvv/+e7Ro0YKFSHqNk2KFKVOmoKCg\ngK
eQk875/vvvMXnyZEyfPh0zZsyAqanu/VtWqVTC19cXS5YswYABA0THIXpuuvenS4DDhw/jxx9/\nRGRkpOgoRGreeecdnDt3DgcOHEDfvn2Rk5MjOpKaffv2oWHDhujfv7/oKER1YvSlWFRUhIiICIO7\neJoMi5ubG44cOYJevXqhU6dO2LNnj+hIlVQqFebNm4fZs2cb3M48ZHyMfvl0woQJKCsrw/r160VH\nIXomJ0+eRGhoKF599VUsXboUjRs3FponNjYW7733Hi5duqSTS7tEtWHUv4MPHjyIAwcOYPny5aKj\nED2zbt26ITExEQ8ePEBAQACSkpKE5pk3bx5mzZrFQiSDYLS/iwsKCjB27Fh8++23aNKkieg4RLXS\npEkTbNu2DTNnzkRwcDBWrlwJEYs+p06dwvXr1yGTybT+2USaYLTLp2PHjoVEIsHatWtFRyGqk4yM\nDMhkMjRv3hwbN27U6taEgwYNwsCBAzFx4kStfSaRJhnlpHjgwAH89ttvWLp0qegoRHXm4eGBEydO\nwNvbG1KpFAcPHtTK5168eBHnz5/H6NGjtfJ5RNpgdJPigwcP4O3tjejoaPTu3Vt0HKJ6FRsbi1Gj\nRmH48OGYP38+GjZsqLHPGj58OPz9/fHBBx9o7DOItM3oSnHUqFGwsrLCqlWrREch0oh79+4hIiIC\n2dnZ2LZtGzw9Pev9M9LT0xEUFIRr167pxRZ0RM/KqJZPf/rpJxw/fhwLFy4UHYVIY+zs7LB7926M\nHTsW3bt3x/r16+v9JJxFixZh8uTJLEQyOEYzKd6/fx8+Pj7Yvn07evbsKToOkVZcunQJMpkM7dq1\nw9q1a+tlg4qbN29CKpUiIyMDtra29ZCSSHcYzaT4/vvv45133mEhklHp0KED4uPj4ejoCKlUimPH\njtX5PZcuXYqIiAgWIhkko5gUd+3ahZkzZ+LixYvCd/8gEmX//v0YO3Ysxo0bh08++QQNGjSo9Xvk\n5eWhXbt2uHTpElq0aKGBlERiGXwp3r17Fz4+Pti5cyeCgoJExyESKjc3F6NGjUJxcTFiYmLQunXr\nal8nl8uRlZWFoqIilJWVwczMDFZWVoiOjsa9e/ewevVqLScn0g6DL8WhQ4fCzc2N1yQSVVAqlfjy\nyy+xcOFCREVFVdmNJj8/HxkZGcjLy6t87WMmJiaQy+Vo1qwZfH190bRpU61nJ9I0gy7FHTt24LPP\nPsOFCxfQqFEj0XGIdMqFCxcgk8kQGBiIlStX4v79+0hLS4NCoXjq90okEnh5eaFVq1aaD0qkRQZb\ninfu3IGvry/27NmDrl27io5DpJOKi4vx73//G6WlpRg6dGiV54qKihAVFYWEhARYW1tj5MiR6NWr\nV+XzLEYyRAZZiiqVCkOGDIGuEnNrAAAVE0lEQVSnpycWLFggOg6RTsvPz8exY8fU7oW4ZMkSKJVK\nvP/++7h27Ro+//xzLF68GC1btqx8jUQiQWBgIJdSyWAY5CUZ27dvR3p6OubMmSM6CpHOy8jIUCvE\nR48e4eTJkwgLC0OjRo3QoUMHdOnSBUeOHKnyOoVCgYyMDG3GJdIogyvF3NxcTJs2DZs2bYK5ubno\nOEQ6TS6XV55U83c5OTkwNTWFs7Nz5WOtW7fGzZs31V6bl5cHuVyu0ZxE2mJQpahSqTBu3DiMGzcO\n/v7+ouMQ6bysrKxqH3/06JHaNb2WlpZ4+PBhta/Pzs6u92xEItT+6l0dFh0djZs3b+KHH34QHYVI\nLxQWFla57OIxCwsLlJSUVHmspKSk2rO4lUolCgsLNZaRSJsMphSzs7PxwQcf4NChQxq9XQ6RvlCp\nVHjw4AGysrIqf928ebPK1yNHjqx2VcXZ2RlKpRK3bt2Ck5MTAOD69etwc3Or9rPKyso0+rMQaYtB\nlKJKpcK7776LKVOmQCqVio5DpBV//fVXjYWXlZWFBg0awNXVtfKX
m5sb+vXrV/l1fn4+bt++rfbe\nFhYWCAwMRExMDKZMmYJr167hzJkzWLx4cbVZzMzMNP3jEmmFQVySsX79enz11Vc4c+YM/3CSQSgt\nLUVOTk6NpVdSUlKl8B6X3t+/tra2rvFzMjIykJ6eXu0SalFRESIjI5GYmAgrKyuMGjWqynWKj5ma\nmsLT0xPu7u719eMTCaP3pXjz5k107twZhw8fhre3t+g4RE+lVCpx+/btGgvv3r17aNGiRY2FZ2dn\np3YpRW3J5XLExsZWW4rPytTUFMHBwTzbmwyCXi+fqlQqREREYNq0aSxE0gmPj+P9s+T+/vWtW7dg\nY2OjNuUFBgZW/neLFi0gkUg0ntfc3BwODg7VLqE+C6VSCUtLSxYiGQy9LsW1a9eioKAAM2bMEB2F\njERxcbHacbt/FqCZmZnaVDdgwIDK/3ZxcdGpEvHw8MDdu3efac/T6syYMQMTJ07E6NGj6zy5Eomm\nt8un169fR0BAAOLi4tC+fXvRccgAlJaWIjs7u8bSe/To0VOP41lZWYn+UWotMzPzmTcDf+zx3qfF\nxcWQyWTw8vLC2rVrYWNjo8GkRJqll6WoVCoRHByMgQMH4sMPPxQdh/SAQqHAnTt3alzWvH//Plq0\naKFWcn//ulmzZgY7DdWmGP+5GfjDhw8xc+ZM7NmzBzExMejevbuG0xJphl6W4qpVqxATE4Pjx49r\n5bgL6TaVSoU///yzxsLLzc2FjY2NWuH9vfQcHR2N/vdTTfdTNDX93wZYDg4O8PDwqHYT8H379mHs\n2LEYP348PvnkEzRooNdHaMgI6V0pZmRk4KWXXsKJEyfg6ekpOg5pwePjeDWVnrm5ebVLmY+/dnZ2\n1qnjeLpOLpcjOzsbhYWFKCsrg5mZGaytrZ/peGhubi5GjRqFv/76CzExMby1FOkVvSpFpVKJl19+\nGW+99RamTZsmOg7VA7lcjpycnGovPH/8mFwur7HwXF1d8cILL4j+UehvlEolli9fjsWLFyMqKgrD\nhw8XHYnomehVKX755ZfYtWsXjh49avTLXPpAoVBUXo/3pCnvwYMH1R7H+3vp2draGuxxPEN3/vx5\nyGQyBAUFISoqSi9PQiLjojelePnyZQQFBeH06dPw8PAQHcfoqVQq3L9/v8bCu337NmxtbWuc8po3\nb85/4Bi44uJiTJ06FXFxcdi2bRsCAgJERyJ6Ir0oRYVCge7duyM0NBTvvfee6DhGoaio6InX4d28\neRPZ2dmwsLCotvAeP+bs7MzN2anSjh078N577+GDDz7ABx98UHniDpEu0YtSXLJkCQ4cOIDY2Fj+\nQaoHj0+iqGnKKysre+J1eG5ubnBxceFxPKq1GzduIDQ0FBYWFoiOjq68AweRrtD5UkxNTUXPnj1x\n9uxZtG7dWnQcnadQKJCbm/vEuybcvHkT+fn5cHJyqnFZ08bGhsfxSCPKy8sxf/58rF69GuvWrcMb\nb7whOhJRJZ0uxfLycnTr1g1jxozBhAkTRMcRTqVS4d69ezUua96+fRt2dnY1TnnNmzfnxE3CnThx\nAqGhoXjttdewdOnSam9gTKRtQktRLpcjKysLRUVFlddCWVlZwdXVFebm5pg/fz6OHDmCgwcPGsXU\nUlhY+NTjeI0bN66x8JycnHgcj/RGfn4+JkyYgJSUFGzfvp0b+5NwQkrxWXbNMDc3x6xZs7Bz584n\n3u1bnzx69KjKvprVlZ5CoahxT01XV1dYWlqK/lGI6pVKpcLmzZvx4Ycf4rPPPsPkyZON4h/BpJu0\nXorPur+iUqmEiYkJvL29dX5HjPLy8srjeE86lpefnw9nZ+cap7ymTZvyLwMyWleuXEFISAgcHR2x\nYcMG2Nvbi45ERkirpZiZmYnU1NRa3dD0nxsPa5tKpcLdu3drvFXQ7du3YW9vX+OEx+N4RE9XWlqK\nTz75BFu3bsWmTZvQt29f0ZHI
yGitFPPz83Hq1KlqJ8Rly5bh4sWLePToEWxsbPDWW2+hf//+lc9L\nJBIEBgZWuwFxXRUUFNRYeNnZ2bC0tKxxWdPZ2RlmZmb1no3IWP32228IDw+HTCbDvHnzeJyctEZr\npXju3Lkn3t37xo0bcHJygpmZGbKysjB79mx89tlnVXaucXR0hL+/f60+8+HDhzUex8vKyoJCoahx\nT01XV1c0bty4Tj87EdXevXv3MGbMGNy6dQvbtm1D27ZtRUciI6CV+7rI5fLKk2qq07Jly8r/NjEx\ngYmJCXJzc6uUYl5eHuRyeeUO/eXl5bh161aNU15BQQFcXFyqFFynTp0wePDgyq95HI9IN9nZ2WHP\nnj1Ys2YNgoKCsHjxYoSHh/PPK2mUVibFjIwMpKen13gscfXq1YiNjUVpaSnatGmDhQsXVrluSaFQ\nIDExEb/++iuysrJw584dODg41DjhOTg48DgekQFISUmBTCZD+/btsXbtWo0cSiECtFSKCQkJyMnJ\neerrFAoF/vjjD6SkpGDIkCFqNygtKyurPKHl8XIrERmHhw8fYsaMGdi7dy+2bt2K7t27i45EBkgr\nY1RZWdkzvU4ikaBDhw64d+8eDhw4oPa8s7MzunfvjpYtW7IQiYxMo0aNsHLlSqxcuRJvv/025syZ\ng/LyctGxyMBopRRrW2CP78NX1/chIsMzaNAgXLhwASdOnECvXr1w48YN0ZHIgGilFK2srJ54bC8/\nPx9xcXF4+PAhFAoFLly4gLi4OPj4+FQNamoKa2trbcQlIh3n5OSEX3/9FYMHD0ZAQAC+++470ZHI\nQGjlmKJcLkdsbGy1J9oUFBRgwYIFyMzMhFKphIODAwYNGlTlOkXgf6UYHBxcefYpERHwv8u9QkJC\nEBQUhJUrV/KWZlQnOnGd4rN4nusUicg4FBcX4/3338fx48exbds2/l1Bz01r1yt4eHhAIpE81/dK\nJJIq1ywSEf3dCy+8gA0bNmDu3LkYOHAgFi9eXKvtJIke0/rep8+yGfjfid77lIj0S2ZmJsLCwmBh\nYYHo6Gg4OTmJjkR6RKtXtrdq1QpeXl7PPDGyEImotlq1aoWjR4+iR48e6NSpE/bu3Ss6EukRnb2f\nooODAzw8PLhzBRE9t+PHjyMsLAyvv/46lixZUmWXLKLqCCnFx+RyObKzs1FYWIiysjKYmZnB2toa\nLi4uPMuUiOpFfn4+xo8fj9TUVGzfvh0dO3YUHYl0mNBSJCLSBpVKhU2bNmHGjBmYM2cOJk2axI3F\nqVosRSIyGunp6QgJCYGTkxM2bNgAOzs70ZFIx/AWEkRkNNq2bYuTJ0+iXbt2kEql+O2330RHIh3D\nSZGIjNJvv/2GUaNGITQ0FF988QUaNmwoOhLpAE6KRGSU+vTpg8TERKSlpaFbt25IT08XHYl0AEuR\niIyWvb09fvrpJ4wePRrdunXDxo0bwcUz48blUyIiAMnJyZDJZOjYsSO+/vprXiNtpDgpEhEB8Pb2\nxtmzZ2FnZwepVIoTJ06IjkQCcFIkIvqHn376CePGjcPEiRPx8ccfo0GDBqIjkZawFImIqnHr1i2M\nHDkSjx49QkxMDFq2bCk6EmkBl0+JiKrh5OSEgwcPYvDgwQgICMCOHTtERyIt4KRIRPQU586dg0wm\nQ48ePRAVFYUXXnhBdCTSEE6KRERP4e/vj4SEBABAp06dcP78ecGJSFMkc+bMmSM6BBGRrmvYsCEG\nDx4Me3t7hIWFAQACAwO5sbiB4fIpEVEtZWZmIjQ0FI0bN0Z0dDRatGghOhLVEy6fEhHVUqtWrfD7\n778jKCgIfn5+2Lt3r+hIVE84KRIR1cHx48cRFhaGQYMGYfHixWjUqJHoSFQHnBSJiOqge/fuSEhI\nwJ07d9ClSxekpKSIjkR1wFIkIqojGxsbfPfdd5g+fTpeeeUVrF69mhuL6ykunxIR1aP09HTIZD
I4\nOztjw4YNsLOzEx2JaoGTIhFRPWrbti1OnTqFdu3aQSqVIjY2VnQkqgVOikREGnLo0CGEh4cjLCwM\nc+fORcOGDUVHoqfgpEhEpCF9+/ZFYmIiUlNTERQUhCtXroiORE/BUiQi0iB7e3v89NNPGDVqFLp1\n64ZNmzbxJBwdxuVTIiItSU5Ohkwmg7e3N9asWYOmTZuKjkT/wEmRiEhLvL29cfbsWdja2kIqleLE\niROiI9E/cFIkIhLgp59+wrhx4zBp0iTMnj0bDRo0EB2JwFIkIhImJycHI0eORGlpKbZu3YqWLVuK\njmT0uHxKRCSIs7MzDh06hEGDBiEgIAA7duwQHcnocVIkItIBZ8+eRUhICHr27InIyEi88MILoiMZ\nJU6KREQ6ICAgABcuXIBSqUTnzp1x/vx50ZGMEkuRiEhHWFlZYePGjfjPf/6DAQMGYOnSpVAqlaJj\nGRUunxIR6aDMzEyEhobC0tISmzdvRosWLURHMgqcFImIdFCrVq3w+++/IzAwEH5+fti3b5/oSEaB\nkyIRkY47duwYwsLC8MYbb2DJkiWwsLAQHclgcVIkItJxPXr0QGJiIu7cuYOAgABcunRJdCSDxVIk\nItIDNjY2+O677zBt2jT06tULq1ev5sbiGsDlUyIiPXP58mWEhITAxcUF69evh52dnehIBoOTIhGR\nnvH09MTJkyfRtm1bSKVSxMbGio5kMDgpEhHpsYMHD2L06NEYMWIEPv/8czRs2FB0JL3GSZGISI/1\n69cPCQkJSElJQVBQEK5cuSI6kl5jKRIR6TkHBwfs3bsXI0eORLdu3bB582aehPOcuHxKRGRAkpKS\nIJPJ4OPjg6+//hpNmjQRHUmvcFIkIjIgPj4+OHfuHGxsbCCVSnHy5EnRkfQKJ0UiIgO1Z88ejBs3\nDpMnT8bHH38MiUQiOpLOYykSERmwnJwcjBw5EmVlZdi6dSvc3NxER9JpXD4lIjJgzs7OOHjwIF57\n7TX4+/tj586doiPpNE6KRERG4uzZs5DJZOjVqxciIyNhaWkpOpLO4aRIRGQkAgICkJCQgPLycnTq\n1AkXLlwQHUnnsBSJiIyIlZUVNm3ahDlz5qB///5YtmwZlEql6Fg6g8unRERG6vr16wgNDYWVlRU2\nb94MR0dH0ZGE46RIRGSkWrdujbi4OHTt2hV+fn7Yv3+/6EjCcVIkIiLExcVhxIgRGDx4MBYvXgwL\nCwvRkYTgpEhEROjZsycSExORm5uLLl264NKlS6IjCcFSJCIiAICNjQ127NiBqVOn4uWXX8aaNWuM\nbmNxLp8SEZGay5cvQyaTwc3NDd9++y3s7OxER9IKTopERKTG09MTp06dgoeHB6RSKQ4fPiw6klZw\nUiQiohodPHgQo0ePxsiRI/H555/DzMxMdCSN4aRIREQ16tevHxISEpCUlISgoCBkZGSIjqQxLEUi\nInoqBwcH7Nu3DyNGjEBgYCCio6MN8iQcLp8SEVGtJCUlQSaTwdfXF2vWrEGTJk1ER6o3nBSJiKhW\nfHx8cPbsWTRp0gR+fn44deqU6Ej1hpMiERE9t927d2P8+PF47733MHv2bEgkEtGR6oSlSEREdZKT\nk4MRI0agvLwcW7duhZubm+hIz43Lp0REVCfOzs44dOgQBg4cCH9/f+zcuVN0pOfGSZGIiOpNfHw8\nQkJC8Morr2DFihWwtLR84mvlcjmysrJQVFSEsrIymJmZwcrKCq6urjA3N9di6v+PpUhERPWqqKgI\n7733Hk6fPo3t27ejU6dOVZ7Pz89HRkYG8vLyAKDKTY5NTf+3gOng4AAPDw80bdpUe8EBSObMmTNH\nq59IREQGzdzcHG+++SaaNWuG0NBQNGjQAF27doWJiQkyMzORkJCAoqIiqFQqtWsdHz9WXFyMnJwc\nmJmZabUYOSkSEZHGXL9+HSEhIbC2tsaSJUuQnZ1dZTJ8Go
lEAi8vL7Rq1UpzIf+GJ9oQEZHGtG7d\nGnFxcejduzeuXLnyxEK8desW3nrrLSxbtqzK4wqFAmlpacjPz9dGXJYiERFplpmZGYKDg2s8eWbN\nmjV48cUXq31OoVBobb9VliIREWmUXC6vPKmmOnFxcbC0tISvr+8TX5OXlwe5XK6JeFWwFImISKOy\nsrKe+FxJSQliYmIQERHx1PfJzs6uz1jVYikSEZFGFRUVPfFY4tatW9G3b1/Y29vX+B5KpRKFhYWa\niFcFS5GIiDSqrKys2sevXbuGxMREDB48uE7vU58aaPwTiIjIqJmZmVX7eHJyMvLy8jBmzBgAwKNH\nj6BUKnHz5k1ERkY+8/vUJ5YiERFplJWVFUxNTdWWUPv374+ePXtWfr1r1y7k5eVh0qRJau9hamoK\na2trjWdlKRIRkUa5uroiPT1d7XELCwtYWFhUft2oUSM0bNjwiTctdnFx0VjGx7ijDRERady5c+dw\n+/bt5/5+R0dH+Pv712Oi6vFEGyIi0jgPD4/nvgGxRCKBh4dHPSeqHkuRiIg0rmnTpvDy8qp1MT7e\n+1Rbm4KzFImISCtatWpVq2LU9mbgAI8pEhGRluny/RRZikREJIRcLkd2djYKCwtRVlYGMzMzWFtb\nw8XFpcbNwzWJpUhERFSBxxSJiIgqsBSJiIgqsBSJiIgqsBSJiIgqsBSJiIgqsBSJiIgqsBSJiIgq\nsBSJiIgqsBSJiIgqsBSJiIgqsBSJiIgqsBSJiIgqsBSJiIgqsBSJiIgqsBSJiIgqsBSJiIgqsBSJ\niIgq/D/kBUCHYTZD8AAAAABJRU5ErkJggg==\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x10a0eb6d8>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import networkx as nx\n",
"# DGL graphs are directed (G stores each edge in both directions);\n",
"# convert to an undirected networkx graph for plotting.\n",
"nx_G = G.to_networkx().to_undirected()\n",
"pos = nx.kamada_kawai_layout(nx_G)\n",
"nx.draw(nx_G, pos, with_labels=True, node_color=[[.7, .7, .7]])"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"# Set to True to print the tensors flowing through message passing. This is very\n",
"# verbose: one dump per message batch and per reduce call on every forward pass.\n",
"DEBUG_MESSAGES = False\n",
"\n",
"def gcn_message(edges):\n",
"    \"\"\"Message UDF: every edge forwards its source node's 'h' feature as 'msg'.\"\"\"\n",
"    if DEBUG_MESSAGES:\n",
"        print(\"gcn_message:\", edges.src['h'])\n",
"    return {'msg': edges.src['h']}\n",
"\n",
"def gcn_reduce(nodes):\n",
"    \"\"\"Reduce UDF: each node's new 'h' is the sum of its incoming messages.\n",
"\n",
"    nodes.mailbox['msg'] is batched by in-degree with shape\n",
"    (nodes in batch, in-degree, feature size), so dim=1 sums over neighbours.\n",
"    \"\"\"\n",
"    if DEBUG_MESSAGES:\n",
"        print(\"gcn_reduce:\", nodes.mailbox['msg'])\n",
"    return {'h': torch.sum(nodes.mailbox['msg'], dim=1)}\n",
"\n",
"class GCNLayer(nn.Module):\n",
"    \"\"\"One graph-convolution step: sum neighbour features, then apply a linear map.\"\"\"\n",
"\n",
"    def __init__(self, in_feats, out_feats):\n",
"        super(GCNLayer, self).__init__()\n",
"        self.linear = nn.Linear(in_feats, out_feats)\n",
"\n",
"    def forward(self, g, inputs):\n",
"        \"\"\"Aggregate `inputs` (one row per node of g) over g's edges and project them.\n",
"\n",
"        The features are stored temporarily in g.ndata['h'] and popped again\n",
"        before returning, so g is left without an 'h' field.\n",
"        \"\"\"\n",
"        g.ndata['h'] = inputs\n",
"        # update_all sends messages along every edge and reduces them on the\n",
"        # receiving nodes -- the same work as send(all edges) + recv(all nodes),\n",
"        # but as a single call that is still supported in DGL >= 0.5.\n",
"        g.update_all(gcn_message, gcn_reduce)\n",
"        h = g.ndata.pop('h')\n",
"        return self.linear(h)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"class GCN(nn.Module):\n",
"    \"\"\"GCN node classifier returning raw (unnormalised) class scores.\n",
"\n",
"    With the default hidden_size=None this is a single GCNLayer mapping\n",
"    in_feats -> num_classes. Passing hidden_size stacks two layers\n",
"    (in_feats -> hidden_size -> num_classes) with a ReLU in between.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, in_feats, num_classes, hidden_size=None):\n",
"        super(GCN, self).__init__()\n",
"        if hidden_size is None:\n",
"            self.gcn1 = GCNLayer(in_feats, num_classes)\n",
"            self.gcn2 = None\n",
"        else:\n",
"            self.gcn1 = GCNLayer(in_feats, hidden_size)\n",
"            self.gcn2 = GCNLayer(hidden_size, num_classes)\n",
"\n",
"    def forward(self, g, inputs):\n",
"        h = self.gcn1(g, inputs)\n",
"        if self.gcn2 is not None:\n",
"            h = self.gcn2(g, torch.relu(h))\n",
"        # No activation on the output: these are logits for log_softmax/nll_loss.\n",
"        # A final ReLU would clamp negative scores to 0 and block their gradients.\n",
"        return h"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Seed torch so the random weight initialisation, and therefore the\n",
"# training losses below, are reproducible across runs.\n",
"torch.manual_seed(0)\n",
"\n",
"# One-hot node IDs as input features: row i of the identity matrix is node i.\n",
"inputs = torch.eye(G.number_of_nodes())\n",
"net = GCN(inputs.shape[1], 2)\n",
"# Every node is labelled: node labeled_nodes[k] belongs to class labels[k].\n",
"labeled_nodes = torch.tensor([0, 1, 2, 3, 4])\n",
"labels = torch.tensor([0, 0, 1, 1, 0])"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"gcn_message: tensor([[1., 0., 0., 0., 0.],\n",
" [1., 0., 0., 0., 0.],\n",
" [1., 0., 0., 0., 0.],\n",
" [1., 0., 0., 0., 0.],\n",
" [0., 1., 0., 0., 0.],\n",
" [0., 0., 1., 0., 0.],\n",
" [0., 1., 0., 0., 0.],\n",
" [0., 0., 1., 0., 0.],\n",
" [0., 0., 0., 1., 0.],\n",
" [0., 0., 0., 0., 1.],\n",
" [0., 0., 1., 0., 0.],\n",
" [0., 0., 0., 1., 0.]])\n",
"gcn_reduce: tensor([[[0., 1., 0., 0., 0.],\n",
" [0., 0., 1., 0., 0.],\n",
" [0., 0., 0., 1., 0.],\n",
" [0., 0., 0., 0., 1.]]])\n",
"gcn_reduce: tensor([[[1., 0., 0., 0., 0.],\n",
" [0., 1., 0., 0., 0.],\n",
" [0., 0., 0., 1., 0.]]])\n",
"gcn_reduce: tensor([[[1., 0., 0., 0., 0.],\n",
" [0., 0., 1., 0., 0.]],\n",
"\n",
" [[1., 0., 0., 0., 0.],\n",
" [0., 0., 1., 0., 0.]]])\n",
"gcn_reduce: tensor([[[1., 0., 0., 0., 0.]]])\n",
"Epoch 0 | Loss: 0.6980\n",
"gcn_message: tensor([[1., 0., 0., 0., 0.],\n",
" [1., 0., 0., 0., 0.],\n",
" [1., 0., 0., 0., 0.],\n",
" [1., 0., 0., 0., 0.],\n",
" [0., 1., 0., 0., 0.],\n",
" [0., 0., 1., 0., 0.],\n",
" [0., 1., 0., 0., 0.],\n",
" [0., 0., 1., 0., 0.],\n",
" [0., 0., 0., 1., 0.],\n",
" [0., 0., 0., 0., 1.],\n",
" [0., 0., 1., 0., 0.],\n",
" [0., 0., 0., 1., 0.]])\n",
"gcn_reduce: tensor([[[0., 1., 0., 0., 0.],\n",
" [0., 0., 1., 0., 0.],\n",
" [0., 0., 0., 1., 0.],\n",
" [0., 0., 0., 0., 1.]]])\n",
"gcn_reduce: tensor([[[1., 0., 0., 0., 0.],\n",
" [0., 1., 0., 0., 0.],\n",
" [0., 0., 0., 1., 0.]]])\n",
"gcn_reduce: tensor([[[1., 0., 0., 0., 0.],\n",
" [0., 0., 1., 0., 0.]],\n",
"\n",
" [[1., 0., 0., 0., 0.],\n",
" [0., 0., 1., 0., 0.]]])\n",
"gcn_reduce: tensor([[[1., 0., 0., 0., 0.]]])\n",
"Epoch 1 | Loss: 0.6946\n",
"gcn_message: tensor([[1., 0., 0., 0., 0.],\n",
" [1., 0., 0., 0., 0.],\n",
" [1., 0., 0., 0., 0.],\n",
" [1., 0., 0., 0., 0.],\n",
" [0., 1., 0., 0., 0.],\n",
" [0., 0., 1., 0., 0.],\n",
" [0., 1., 0., 0., 0.],\n",
" [0., 0., 1., 0., 0.],\n",
" [0., 0., 0., 1., 0.],\n",
" [0., 0., 0., 0., 1.],\n",
" [0., 0., 1., 0., 0.],\n",
" [0., 0., 0., 1., 0.]])\n",
"gcn_reduce: tensor([[[0., 1., 0., 0., 0.],\n",
" [0., 0., 1., 0., 0.],\n",
" [0., 0., 0., 1., 0.],\n",
" [0., 0., 0., 0., 1.]]])\n",
"gcn_reduce: tensor([[[1., 0., 0., 0., 0.],\n",
" [0., 1., 0., 0., 0.],\n",
" [0., 0., 0., 1., 0.]]])\n",
"gcn_reduce: tensor([[[1., 0., 0., 0., 0.],\n",
" [0., 0., 1., 0., 0.]],\n",
"\n",
" [[1., 0., 0., 0., 0.],\n",
" [0., 0., 1., 0., 0.]]])\n",
"gcn_reduce: tensor([[[1., 0., 0., 0., 0.]]])\n",
"Epoch 2 | Loss: 0.6917\n"
]
}
],
"source": [
"G.set_n_initializer(dgl.init.zero_initializer)\n",
"optimizer = torch.optim.Adam(params=net.parameters(), lr=0.01)\n",
"all_logits = []  # detached logits from every epoch, kept for later inspection\n",
"\n",
"for epoch in range(3):\n",
"    logits = net(G, inputs)\n",
"    all_logits.append(logits.detach())\n",
"    # Negative log-likelihood, evaluated only on the labelled nodes.\n",
"    logp = F.log_softmax(logits, dim=1)\n",
"    loss = F.nll_loss(logp.index_select(0, labeled_nodes), labels)\n",
"\n",
"    # Clear stale gradients, backpropagate, then apply the Adam update.\n",
"    optimizer.zero_grad()\n",
"    loss.backward()\n",
"    optimizer.step()\n",
"\n",
"    print('Epoch %d | Loss: %.4f' % (epoch, loss.item()))"
]
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python [default]",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment