# Created April 30, 2020 01:43
# RGCN implementation from scratch. Untested in gist form. Let me know if you need this for something.
import torch, os, sys
from torch import nn
import torch.nn.functional as F
import torch.distributions as ds
from math import sqrt, ceil
import layers, util
import torch as T
class RGCNClassic(nn.Module):
    """
    Classic RGCN: a two-layer relational graph convolution network for node classification.

    The model takes no input features. The first layer holds one ``emb``-dimensional weight
    row per (relation, node) pair, and the second layer maps the hidden representation to
    ``numcls`` outputs per node. Both layers optionally use a basis decomposition, in which
    the per-relation weights are linear combinations of ``bases`` shared weight tensors.
    """

    def __init__(self, edges, n, numcls, emb=16, bases=None, softmax=False):
        """
        :param edges: Edge dictionary, passed on unchanged to ``adj`` (relation -> (from, to)).
        :param n: Number of nodes in the graph.
        :param numcls: Number of output classes.
        :param emb: Size of the hidden layer.
        :param bases: Number of bases for the basis decomposition, or None to learn a full
            weight tensor per relation.
        :param softmax: If True, ``forward`` returns softmax-normalized outputs; otherwise it
            returns raw scores.
        """
        # nn.Module must be initialized before any buffer or parameter is registered
        super().__init__()

        self.emb = emb
        self.bases = bases
        self.numcls = numcls
        self.softmax = softmax

        # horizontally and vertically stacked versions of the adjacency graph
        hor_ind, hor_size = adj(edges, n, vertical=False)
        ver_ind, ver_size = adj(edges, n, vertical=True)

        _, rn = hor_size
        r = rn // n  # number of relations

        vals = torch.ones(ver_ind.size(0), dtype=torch.float)
        # row-normalize
        # NOTE(review): assumes util.sum_sparse returns, for every edge, the sum of its row
        # in the vertically stacked matrix -- confirm against util.
        vals = vals / util.sum_sparse(ver_ind, vals, ver_size)

        # -- the values are the same for the horizontal and the vertically stacked adjacency
        #    matrices so we can just normalize them by the vertically stacked one and reuse
        #    for the horizontal (both index lists enumerate the edges in the same order)
        hor_graph = torch.sparse_coo_tensor(hor_ind.t(), vals, hor_size)
        self.register_buffer('hor_graph', hor_graph)

        ver_graph = torch.sparse_coo_tensor(ver_ind.t(), vals, ver_size)
        self.register_buffer('ver_graph', ver_graph)

        gain = nn.init.calculate_gain('relu')

        # layer 1 weights
        if bases is None:
            self.weights1 = nn.Parameter(torch.empty(r, n, emb))
            nn.init.xavier_uniform_(self.weights1, gain=gain)

            self.bases1 = None
        else:
            self.comps1 = nn.Parameter(torch.empty(r, bases))
            nn.init.xavier_uniform_(self.comps1, gain=gain)

            self.bases1 = nn.Parameter(torch.empty(bases, n, emb))
            nn.init.xavier_uniform_(self.bases1, gain=gain)

        # layer 2 weights
        if bases is None:
            self.weights2 = nn.Parameter(torch.empty(r, emb, numcls))
            nn.init.xavier_uniform_(self.weights2, gain=gain)

            self.bases2 = None
        else:
            self.comps2 = nn.Parameter(torch.empty(r, bases))
            nn.init.xavier_uniform_(self.comps2, gain=gain)

            self.bases2 = nn.Parameter(torch.empty(bases, emb, numcls))
            nn.init.xavier_uniform_(self.bases2, gain=gain)

        self.bias1 = nn.Parameter(torch.zeros(emb))
        self.bias2 = nn.Parameter(torch.zeros(numcls))

    def forward(self):
        """
        Runs both layers over the whole graph.

        :return: Tensor of size (n, numcls); softmax-normalized over the class dimension if
            ``self.softmax`` is set, raw scores otherwise.
        """
        n, rn = self.hor_graph.size()
        r = rn // n
        e = self.emb
        b, c = self.bases, self.numcls

        ## Layer 1

        if self.bases1 is not None:
            # weights = torch.einsum('rb, bij -> rij', self.comps1, self.bases1)
            weights = torch.mm(self.comps1, self.bases1.view(b, n * e)).view(r, n, e)
        else:
            weights = self.weights1

        assert weights.size() == (r, n, e)

        # Apply weights and sum over relations
        h = torch.mm(self.hor_graph, weights.view(r * n, e))  # sparse mm
        assert h.size() == (n, e)

        h = F.relu(h + self.bias1)

        ## Layer 2

        # Multiply adjacencies by hidden
        h = torch.mm(self.ver_graph, h)  # sparse mm
        h = h.view(r, n, e)  # new dim for the relations

        if self.bases2 is not None:
            # weights = torch.einsum('rb, bij -> rij', self.comps2, self.bases2)
            weights = torch.mm(self.comps2, self.bases2.view(b, e * c)).view(r, e, c)
        else:
            weights = self.weights2

        # Apply weights, sum over relations
        # h = torch.einsum('rhc, rnh -> nc', weights, h)
        h = torch.bmm(h, weights).sum(dim=0)
        assert h.size() == (n, c)

        if self.softmax:
            return F.softmax(h + self.bias2, dim=1)

        return h + self.bias2  # -- softmax is applied in the loss
def adj(edges, num_nodes, cuda=False, vertical=True):
    """
    Computes the indices and size of a sparse adjacency matrix for the given graph, in which
    the adjacency matrices of all relations are stacked vertically (or horizontally).

    :param edges: Dictionary representing the edges. Maps each relation index to a pair
        ``(from_nodes, to_nodes)`` of lists of node indices, one entry per edge. Relations
        are expected to be numbered 0 .. r-1; the bounds checks below fail otherwise.
    :param num_nodes: Number of nodes n in the graph.
    :param cuda: If True the index tensor is created on the 'cuda' device, otherwise on 'cpu'.
    :param vertical: If True, relation ``rel`` occupies rows ``rel*n .. rel*n + n - 1`` of an
        (r*n, n) matrix. If False, it occupies the corresponding columns of an (n, r*n) matrix.
    :return: A pair ``(indices, size)``: a LongTensor of size (num_edges, 2) holding one
        (row, column) pair per edge, and the size tuple of the stacked matrix.
    """
    r, n = len(edges), num_nodes
    size = (r * n, n) if vertical else (n, r * n)

    from_indices = []
    upto_indices = []

    for rel, (fr, to) in edges.items():
        offset = rel * n

        # Shift only the stacked dimension into the slot that belongs to this relation.
        if vertical:
            fr = [offset + f for f in fr]
        else:
            to = [offset + t for t in to]

        from_indices.extend(fr)
        upto_indices.extend(to)

    indices = torch.tensor([from_indices, upto_indices], dtype=torch.long,
                           device='cuda' if cuda else 'cpu')

    assert indices.size(1) == sum([len(ed[0]) for _, ed in edges.items()])

    # max() is undefined on an empty tensor, so only bounds-check when there are edges
    if indices.size(1) > 0:
        assert indices[0, :].max() < size[0], f'{indices[0, :].max()}, {size}, {r}, {edges.keys()}'
        assert indices[1, :].max() < size[1], f'{indices[1, :].max()}, {size}, {r}, {edges.keys()}'

    return indices.t(), size
