@pbloem · Created April 30, 2020 01:43
RGCN implementation from scratch. Untested in gist form. Let me know if you need this for something.
import torch
from torch import nn
import torch.nn.functional as F

import util  # the author's helper module (not included in the gist); provides sum_sparse
class RGCNClassic(nn.Module):
    """
    Classic RGCN: two relational graph convolution layers, for full-batch node
    classification.
    """

    def __init__(self, edges, n, numcls, emb=16, bases=None, softmax=False):
        super().__init__()

        self.emb = emb
        self.bases = bases
        self.numcls = numcls
        self.softmax = softmax

        # horizontally and vertically stacked versions of the adjacency graph
        hor_ind, hor_size = adj(edges, n, vertical=False)
        ver_ind, ver_size = adj(edges, n, vertical=True)

        _, rn = hor_size
        r = rn // n

        vals = torch.ones(ver_ind.size(0), dtype=torch.float)
        vals = vals / util.sum_sparse(ver_ind, vals, ver_size)  # row-normalize
        # -- the values are the same for the horizontally and the vertically stacked
        #    adjacency matrices, so we can normalize them once (by the vertical stacking)
        #    and reuse them for the horizontal one

        hor_graph = torch.sparse.FloatTensor(indices=hor_ind.t(), values=vals, size=hor_size)
        self.register_buffer('hor_graph', hor_graph)

        ver_graph = torch.sparse.FloatTensor(indices=ver_ind.t(), values=vals, size=ver_size)
        self.register_buffer('ver_graph', ver_graph)

        # layer 1 weights: one (n, emb) matrix per relation, or a basis decomposition
        # (each relation's weight matrix is a learned mixture of `bases` shared matrices)
        if bases is None:
            self.weights1 = nn.Parameter(torch.FloatTensor(r, n, emb))
            nn.init.xavier_uniform_(self.weights1, gain=nn.init.calculate_gain('relu'))

            self.bases1 = None
        else:
            self.comps1 = nn.Parameter(torch.FloatTensor(r, bases))
            nn.init.xavier_uniform_(self.comps1, gain=nn.init.calculate_gain('relu'))

            self.bases1 = nn.Parameter(torch.FloatTensor(bases, n, emb))
            nn.init.xavier_uniform_(self.bases1, gain=nn.init.calculate_gain('relu'))

        # layer 2 weights
        if bases is None:
            self.weights2 = nn.Parameter(torch.FloatTensor(r, emb, numcls))
            nn.init.xavier_uniform_(self.weights2, gain=nn.init.calculate_gain('relu'))

            self.bases2 = None
        else:
            self.comps2 = nn.Parameter(torch.FloatTensor(r, bases))
            nn.init.xavier_uniform_(self.comps2, gain=nn.init.calculate_gain('relu'))

            self.bases2 = nn.Parameter(torch.FloatTensor(bases, emb, numcls))
            nn.init.xavier_uniform_(self.bases2, gain=nn.init.calculate_gain('relu'))

        self.bias1 = nn.Parameter(torch.FloatTensor(emb).zero_())
        self.bias2 = nn.Parameter(torch.FloatTensor(numcls).zero_())
    def forward(self):

        ## Layer 1
        n, rn = self.hor_graph.size()
        r = rn // n
        e = self.emb
        b, c = self.bases, self.numcls

        if self.bases1 is not None:
            # weights = torch.einsum('rb, bij -> rij', self.comps1, self.bases1)
            weights = torch.mm(self.comps1, self.bases1.view(b, n*e)).view(r, n, e)
        else:
            weights = self.weights1

        assert weights.size() == (r, n, e)

        # Apply weights and sum over relations
        h = torch.mm(self.hor_graph, weights.view(r*n, e))
        assert h.size() == (n, e)

        h = F.relu(h + self.bias1)

        ## Layer 2

        # Multiply adjacencies by hidden
        h = torch.mm(self.ver_graph, h)  # sparse mm
        h = h.view(r, n, e)              # new dim for the relations

        if self.bases2 is not None:
            # weights = torch.einsum('rb, bij -> rij', self.comps2, self.bases2)
            weights = torch.mm(self.comps2, self.bases2.view(b, e*c)).view(r, e, c)
        else:
            weights = self.weights2

        # Apply weights, sum over relations
        # h = torch.einsum('rhc, rnh -> nc', weights, h)
        h = torch.bmm(h, weights).sum(dim=0)

        assert h.size() == (n, c)

        if self.softmax:
            return F.softmax(h + self.bias2, dim=1)

        return h + self.bias2  # -- softmax is applied in the loss
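
# util.sum_sparse is not included in this gist. Below is a minimal sketch of what it is
# assumed to compute (an assumption, not the author's implementation): the sum of the
# values in each row of the sparse matrix, returned per edge, so that
# `vals / sum_sparse(indices, vals, size)` row-normalizes the matrix. It assumes
# `indices` is an (E, 2) tensor of (row, column) pairs, as returned by adj() below.
def sum_sparse(indices, values, size):
    """
    Sketch: sum the values per row of the sparse matrix (indices, values, size) and
    return each edge's row total, in an (E,) tensor aligned with `values`.
    """
    rows = indices[:, 0]
    sums = torch.zeros(size[0], dtype=values.dtype, device=values.device)
    sums.index_add_(0, rows, values)  # accumulate every value into its row's total
    return sums[rows]                 # look the totals back up, one per edge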
def adj(edges, num_nodes, cuda=False, vertical=True):
    """
    Computes the indices and size of a sparse adjacency matrix for the given graph, with
    the adjacency matrices of all relations stacked vertically (or horizontally, if
    vertical=False).

    :param edges: dictionary mapping each relation index to a pair of lists
        (from-nodes, to-nodes)
    :param num_nodes: total number of nodes in the graph
    :param cuda: if true, the index tensor is created on the GPU
    :param vertical: stack the relations vertically, giving an (r*n, n) matrix, rather
        than horizontally, giving (n, r*n)
    :return: (E, 2) long tensor of indices, and the size of the matrix
    """
    r, n = len(edges.keys()), num_nodes
    size = (r*n, n) if vertical else (n, r*n)

    from_indices = []
    upto_indices = []

    for rel, (fr, to) in edges.items():
        offset = rel * n

        if vertical:
            fr = [offset + f for f in fr]
        else:
            to = [offset + t for t in to]

        from_indices.extend(fr)
        upto_indices.extend(to)

    device = 'cuda' if cuda else 'cpu'
    indices = torch.tensor([from_indices, upto_indices], dtype=torch.long, device=device)

    assert indices.size(1) == sum(len(ed[0]) for _, ed in edges.items())
    assert indices[0, :].max() < size[0], f'{indices[0, :].max()}, {size}, {r}, {edges.keys()}'
    assert indices[1, :].max() < size[1], f'{indices[1, :].max()}, {size}, {r}, {edges.keys()}'

    return indices.t(), size
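
# A minimal usage sketch (not part of the original gist). The edge-dictionary format is
# inferred from adj() above: relation index -> (list of from-nodes, list of to-nodes).
# Running this requires a `util` module providing sum_sparse, e.g. the sketch above.
if __name__ == '__main__':

    num_nodes = 4
    edges = {
        0: ([0, 1, 2], [1, 2, 3]),  # relation 0: three edges
        1: ([3, 0], [0, 2]),        # relation 1: two edges
    }

    model = RGCNClassic(edges, n=num_nodes, numcls=2, emb=8)

    out = model()      # the graph is baked into the model, so forward() takes no input
    print(out.size())  # torch.Size([4, 2]): one logit per class, per node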