Graph2Gauss in PyTorch
"""PyTorch implementation of the Graph2Gauss algorithm (https://arxiv.org/pdf/1707.03815.pdf)"""
import walker
import random
import numpy as np
import networkx as nx
import torch
from torch.nn import Linear, Module, Sequential, ReLU
elu = torch.nn.ELU()
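
# Assumed dependencies (not pinned in the original gist): torch, numpy, networkx,
# scikit-learn, tqdm, torch-geometric, and the `walker` random-walk module,
# presumably the `graph-walker` package (pip install graph-walker).
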
def kl(mu1, mu2, sigma1, sigma2):
    # Row-wise KL divergence between two Gaussians with diagonal covariance matrices:
    # KL(N1 || N2) = 0.5 * (log(det2 / det1) - d + tr(S2^-1 S1) + (mu2 - mu1)^T S2^-1 (mu2 - mu1))
    det1, det2 = torch.prod(sigma1, axis=1), torch.prod(sigma2, axis=1)
    d = mu1.shape[1]
    mu_diff = mu2 - mu1
    return 0.5 * (torch.log(det2 / det1 + 1e-14) - d
                  + torch.sum(sigma1 / sigma2, axis=1)
                  + torch.sum(mu_diff / sigma2 * mu_diff, axis=1))
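
# Sanity-check sketch (illustrative, not part of the original gist): identical Gaussians
# have zero divergence, so the call below returns a length-2 tensor of values close to 0:
#   kl(torch.zeros(2, 3), torch.zeros(2, 3), torch.ones(2, 3), torch.ones(2, 3))
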
def g2g_loss(anchor, positive, negative):
    # Square-exponential energy loss from the paper: E_pos^2 + exp(-E_neg),
    # where the energy of a node pair is the KL divergence between their embeddings.
    mu, sigma = anchor
    pmu, psigma = positive
    nmu, nsigma = negative
    epos = kl(mu, pmu, sigma, psigma)
    eneg = kl(mu, nmu, sigma, nsigma)
    return torch.mean(epos**2 + torch.exp(-eneg), axis=0)
def mlp(sizes):
    layers = [
        Sequential(Linear(l1, l2), ReLU()) for l1, l2 in zip(sizes[:-1], sizes[1:])
    ]
    return Sequential(*layers)
class GaussianEncoder(Module):
    def __init__(self, in_dim, hidden_dim, n_layers, embd_dim):
        super().__init__()
        if n_layers == 1:
            self.ff = Sequential(Linear(in_dim, hidden_dim), ReLU())
        else:
            self.ff = mlp([in_dim] + [hidden_dim] * (n_layers - 1))
        self.mu = Linear(hidden_dim, embd_dim)
        self.sigma = Linear(hidden_dim, embd_dim)

    def forward(self, x):
        h = self.ff(x)
        mu, sigma = self.mu(h), self.sigma(h)
        # ELU(x) + 1 keeps the predicted variances strictly positive.
        return mu, elu(sigma) + 1
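
# Minimal usage sketch (shapes chosen for illustration): the encoder maps each row of
# node features to a diagonal Gaussian; mu and sigma both have shape (batch, embd_dim).
#   enc = GaussianEncoder(in_dim=1433, hidden_dim=32, n_layers=2, embd_dim=16)
#   mu, sigma = enc(torch.randn(8, 1433))  # mu: (8, 16), sigma: (8, 16) with sigma > 0
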
class Sampler:
    """Draws (anchor, positive, negative) node triplets from biased random walks."""

    def __init__(self, graph, depth=5):
        self.graph = graph
        self.depth = depth
        self._nodes = list(graph.nodes)

    def sample(self, bs):
        nodes = random.choices(self._nodes, k=bs)
        walks = walker.random_walks(
            self.graph,
            n_walks=1,
            walk_len=self.depth,
            start_nodes=nodes,
            p=100,
            q=0.1,
            verbose=False,
        )
        walks = walks.astype(int)
        anchors, positives, negatives = [], [], []
        # A node visited earlier in the walk is treated as closer to the anchor
        # than the node visited one step later.
        for i in range(0, self.depth - 2):
            for j in range(i + 1, self.depth - 1):
                anchors.append(walks[:, i])
                positives.append(walks[:, j])
                negatives.append(walks[:, j + 1])
        anchors = np.concatenate(anchors, axis=0)
        positives = np.concatenate(positives, axis=0)
        negatives = np.concatenate(negatives, axis=0)
        return anchors, positives, negatives
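
# Illustrative note: with the default depth=5, sample(bs) expands bs walks into all
# (i, j, j + 1) column triples with 0 <= i < j <= 3, i.e. 6 triples per walk, so each
# returned array holds bs * 6 node indices.
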
def train(graph, X, net, opt, depth, epochs, batch_size=64):
    s = Sampler(graph, depth)
    hist = []
    for e in range(epochs):
        rl = []
        for _ in range(X.shape[0] // batch_size):
            anc, pos, neg = s.sample(batch_size)
            xa = net(X[anc])
            xp = net(X[pos])
            xn = net(X[neg])
            loss = g2g_loss(xa, xp, xn)
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), 0.5, error_if_nonfinite=True)
            opt.step()
            rl.append(loss.item())
        hist.append(np.mean(rl))
        print(f"Epoch {e} loss {hist[-1]:.4f}")
    return hist
def evaluate_roc_auc(graph, X, net):
    true = nx.adjacency_matrix(graph).toarray().flatten().tolist()
    pred = []
    with torch.no_grad():
        mu, sigma = net(X)
        for i in tqdm(range(len(graph)), total=len(graph)):
            rmu = torch.stack([mu[i]] * len(graph), dim=0)
            rsig = torch.stack([sigma[i]] * len(graph), dim=0)
            pred = pred + (-kl(rmu, mu, rsig, sigma)).numpy().tolist()
    return roc_auc_score(true, pred)
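
# Note: the evaluation scores every ordered node pair by its negative KL divergence and
# compares that ranking against the full adjacency matrix (training edges included), so
# the reported ROC-AUC measures graph reconstruction rather than held-out link prediction.
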
if __name__ == '__main__':
    from torch_geometric.datasets import Planetoid
    from torch_geometric.utils import to_networkx

    ds = Planetoid("./data", "cora")
    X = ds.data.x
    g = to_networkx(ds.data)
    net = GaussianEncoder(X.shape[1], 32, 2, 16)
    opt = torch.optim.Adam(net.parameters(), 1e-3, weight_decay=1e-5)
    print("Training...")
    train(g, X, net, opt, 7, 50)
    print("Evaluation in progress...")
    print("ROC-AUC", evaluate_roc_auc(g, X, net))