Skip to content

Instantly share code, notes, and snippets.

@ricsi98
Created May 5, 2023 20:22
Show Gist options
  • Save ricsi98/e0e1029844997b8c0ee089a00e9df38f to your computer and use it in GitHub Desktop.
Save ricsi98/e0e1029844997b8c0ee089a00e9df38f to your computer and use it in GitHub Desktop.
Poincaré embeddings for the Karate Club graph
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "bb251a80-b977-409a-a1ee-b944a88fd30a",
"metadata": {},
"outputs": [],
"source": [
"import networkx\n",
"import geoopt\n",
"import torch\n",
"import torch.nn as nn\n",
"import numpy as np\n",
"import random"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "ead8e198-07e7-45ce-a559-6c7cbb92a2ec",
"metadata": {},
"outputs": [],
"source": [
"class KarateClubDataset(torch.utils.data.Dataset):\n",
"    \"\"\"Per-node edge samples from Zachary's karate club graph.\n",
"\n",
"    Item ``i`` holds every edge ``(i, neighbour)`` labelled 1 plus\n",
"    ``num_negatives`` random non-neighbours per neighbour labelled 0.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, num_negatives=1):\n",
"        self.graph = networkx.karate_club_graph()\n",
"        self.nneg = num_negatives\n",
"\n",
"    def __len__(self):\n",
"        return len(self.graph)\n",
"\n",
"    def __getitem__(self, index):\n",
"        \"\"\"Return ``(pairs, labels)`` for node ``index``.\n",
"\n",
"        ``pairs`` is a long tensor of shape (k, 2) whose first column is\n",
"        ``index``; ``labels`` is a double tensor of shape (k,) holding 1 for\n",
"        graph edges and 0 for sampled non-edges.\n",
"        \"\"\"\n",
"        node = index\n",
"        ns = list(self.graph.neighbors(node))\n",
"        neighbors = torch.tensor(ns, dtype=torch.long)\n",
"        positives = torch.stack((torch.full_like(neighbors, node), neighbors), dim=1)\n",
"        # Candidates are nodes that are neither neighbours nor the node itself:\n",
"        # a self-pair has distance 0 and is not a meaningful non-edge.\n",
"        candidates = set(self.graph.nodes()).difference(ns)\n",
"        candidates.discard(node)\n",
"        pool = sorted(candidates)  # explicit, reproducible order for seeded sampling\n",
"        negatives = random.choices(pool, k=len(ns) * self.nneg) if pool else []\n",
"        negatives = torch.tensor(negatives, dtype=torch.long)\n",
"        negatives = torch.stack((torch.full_like(negatives, node), negatives), dim=1)\n",
"        plabels = torch.ones(len(positives), dtype=torch.long)\n",
"        nlabels = torch.zeros(len(negatives), dtype=torch.long)\n",
"        return torch.cat((positives, negatives), dim=0).long(), torch.cat((plabels, nlabels), dim=0).double()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "8ecd1095-6d9c-48e3-971b-9997aa9d0d5f",
"metadata": {},
"outputs": [],
"source": [
"class PoincareEmbedding(nn.Module):\n",
"    \"\"\"Learnable points on a Poincaré ball, one row per id.\n",
"\n",
"    Args:\n",
"        c: curvature parameter passed to ``geoopt.manifolds.PoincareBall``.\n",
"        vocab_size: number of embedded ids.\n",
"        dim: embedding dimension (default 2, previously hard-coded).\n",
"        init_std: std of the normal initialisation before projection onto the ball.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, c, vocab_size, dim=2, init_std=0.45):\n",
"        super().__init__()\n",
"        self._manifold = geoopt.manifolds.PoincareBall(c, learnable=False)\n",
"        data = torch.tensor(np.random.normal(0, init_std, size=(vocab_size, dim)), dtype=torch.double)\n",
"        # Retracting with a zero tangent vector is just a projection onto the ball,\n",
"        # so call projx directly. NOTE(review): the ball radius is 1/sqrt(c)\n",
"        # (~0.58 for c=3), so many samples get clamped to the boundary -- confirm intended.\n",
"        data = self._manifold.projx(data)\n",
"        data = geoopt.ManifoldTensor(data, manifold=self._manifold)\n",
"        self.w = geoopt.ManifoldParameter(data, requires_grad=True)\n",
"\n",
"    def forward(self, ids):\n",
"        \"\"\"Look up embeddings; the result has shape ``ids.shape + (dim,)``.\"\"\"\n",
"        return self.w[ids.reshape(-1)].view(*ids.shape, self.w.shape[-1])\n",
"\n",
"\n",
"class Model(nn.Module):\n",
"    \"\"\"Scores id pairs by squared geodesic distance on the embedding's ball.\"\"\"\n",
"\n",
"    def __init__(self, embedding):\n",
"        super().__init__()\n",
"        self.embd = embedding\n",
"\n",
"    def forward(self, a, b):\n",
"        \"\"\"Return the elementwise ``dist2`` between embeddings of ``a`` and ``b``.\"\"\"\n",
"        emb = self.embd\n",
"        return emb._manifold.dist2(emb(a), emb(b))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "2d1c393e-2b6a-4de0-90ee-86c50bf3b81f",
"metadata": {},
"outputs": [],
"source": [
"# Seed the RNGs this notebook draws from: `random` (negative sampling)\n",
"# and NumPy (embedding initialisation, spectral clustering).\n",
"random.seed(0)\n",
"np.random.seed(0)\n",
"\n",
"ds = KarateClubDataset()\n",
"# Named `embedding` (not `e`) so the training loop variable cannot shadow it.\n",
"embedding = PoincareEmbedding(3, len(ds.graph))\n",
"m = Model(embedding)\n",
"opt = geoopt.optim.RiemannianSGD(m.parameters(), lr=0.01)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "f0fdaf7d-9a1e-4e7d-b8b2-8f351b67fad6",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"loss: 1.398: 100%|██████████████████████████| 1000/1000 [00:41<00:00, 24.25it/s]\n"
]
}
],
"source": [
"from tqdm import tqdm\n",
"\n",
"hist = []  # mean loss of each epoch\n",
"snapshots = []  # copies of the embedding matrix taken during training\n",
"MOD = 50  # snapshot period, in epochs\n",
"\n",
"def crit(d2, y):\n",
"    \"\"\"Mean binary cross-entropy treating ``-d2`` as the edge logit.\n",
"\n",
"    Pairs with ``y == 1`` are pulled together (small ``d2``), pairs with\n",
"    ``y == 0`` are pushed apart. ``logsigmoid`` replaces ``log(sigmoid(.))``,\n",
"    which underflows to ``-inf`` when ``d2`` is large.\n",
"    \"\"\"\n",
"    logsig = torch.nn.functional.logsigmoid\n",
"    return -(y * logsig(-d2) + (1 - y) * logsig(d2)).mean()\n",
"\n",
"E = 1000\n",
"for epoch in (pbar := tqdm(range(E), total=E)):\n",
"    if epoch % MOD == 0:\n",
"        snapshots.append(m.embd.w.data.numpy().copy())\n",
"    bloss = []\n",
"    # One optimiser step per node, using all of that node's positive/negative pairs.\n",
"    for node in range(len(ds)):\n",
"        x, y = ds[node]\n",
"        xa, xb = x[:, 0], x[:, 1]\n",
"        y_ = m(xa, xb)\n",
"        loss = crit(y_, y)\n",
"        bloss.append(loss.item())\n",
"        opt.zero_grad()\n",
"        loss.backward()\n",
"        opt.step()\n",
"    hist.append(np.mean(bloss))\n",
"    # Report this epoch's loss rather than the average over all epochs so far.\n",
"    pbar.set_description(f\"loss: {hist[-1]:.3f}\")\n",
"# In-loop snapshots are taken before an epoch runs, so also keep the final state.\n",
"snapshots.append(m.embd.w.data.numpy().copy())"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "6324d2fd-29ab-4141-8332-724b946b68f4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f491dbc1d80>]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from matplotlib import pyplot as plt\n",
"# Mean training loss per epoch (x = epoch index).\n",
"plt.plot(range(len(hist)), hist)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "cc34cc3f-3b9e-45a2-96fe-bffe4309c3b5",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from sklearn.cluster import SpectralClustering\n",
"\n",
"N = 20  # points sampled along each drawn geodesic edge\n",
"manifold = m.embd._manifold\n",
"w = m.embd.w.data\n",
"\n",
"def snapshot(w, manifold, graph, colors, savename=None, num_points=None):\n",
"    \"\"\"Plot 2-D embeddings ``w`` with graph edges drawn as geodesics.\n",
"\n",
"    Args:\n",
"        w: (num_nodes, 2) tensor of points on ``manifold``.\n",
"        manifold: geoopt manifold providing ``geodesic``.\n",
"        graph: networkx graph whose nodes index rows of ``w``.\n",
"        colors: per-node colours forwarded to ``plt.scatter``.\n",
"        savename: if given, save to ``./figs/<savename>.png`` instead of showing.\n",
"        num_points: samples per edge; defaults to the global ``N``.\n",
"    \"\"\"\n",
"    from pathlib import Path\n",
"\n",
"    n = N if num_points is None else num_points\n",
"    # linspace includes t=1 so each curve reaches its end node (arange(n)/n\n",
"    # stops at (n-1)/n); use w's dtype to avoid mixing float32 and float64.\n",
"    t = torch.linspace(0, 1, n, dtype=w.dtype).view(-1, 1)\n",
"    seen = set()\n",
"    for node in graph:\n",
"        for neigh in graph.neighbors(node):\n",
"            if (neigh, node) in seen:\n",
"                continue  # undirected edge already drawn from the other end\n",
"            seen.add((node, neigh))\n",
"            start = w[node].expand(n, -1)\n",
"            end = w[neigh].expand(n, -1)\n",
"            points = manifold.geodesic(t, start, end)\n",
"            plt.plot(points[:, 0], points[:, 1], c=\"gray\", alpha=0.3, linestyle=\"--\")\n",
"    plt.scatter(w[:, 0], w[:, 1], c=colors)\n",
"    plt.xlim(-1, 1)\n",
"    plt.ylim(-1, 1)\n",
"    if savename is not None:\n",
"        out_dir = Path(\"./figs\")\n",
"        out_dir.mkdir(parents=True, exist_ok=True)  # savefig does not create directories\n",
"        plt.savefig(out_dir / f\"{savename}.png\", dpi=250)\n",
"        plt.close()\n",
"    else:\n",
"        plt.show()\n",
" \n",
" \n",
"# Dense ndarray adjacency: .todense() on a scipy sparse *matrix* yields np.matrix,\n",
"# which recent scikit-learn versions reject.\n",
"A = networkx.to_numpy_array(ds.graph)\n",
"\n",
"# 2-way spectral split of the adjacency matrix, used only to colour the nodes;\n",
"# random_state fixes the k-means step so the colouring is reproducible.\n",
"sc = SpectralClustering(2, affinity='precomputed', n_init=100, random_state=0)\n",
"sc.fit(A)\n",
"colors = sc.labels_\n",
"\n",
"snapshot(torch.tensor(snapshots[-1]).double(), manifold, ds.graph, colors)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment