@nashid
Forked from yzh119/dgl-transformer.py
Created March 12, 2022 00:50
Efficient Sparse Transformer implementation with DGL's builtin operators
import dgl
import dgl.ops as ops
import numpy as np
import torch as th
import torch.nn as nn


class FFN(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear."""
    def __init__(self, d_feat, d_ffn, dropout=0.1):
        super().__init__()
        self.linear_0 = nn.Linear(d_feat, d_ffn)
        self.linear_1 = nn.Linear(d_ffn, d_feat)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.linear_1(self.dropout(th.relu(self.linear_0(x))))


class TransLayer(nn.Module):
    def __init__(self,
                 d_feat=512,
                 h=8,
                 d_qkv=64,
                 d_ffn=2048,
                 dropout_h=0.1,
                 dropout_a=0.1):
        """
        Parameters
        ----------
        d_feat : int
            The feature dimension, defaults to 512.
        h : int
            The number of heads, defaults to 8.
        d_qkv : int
            The query/key/value dimension per head, defaults to 64.
        d_ffn : int
            The inner feature dimension in the FFN layer, defaults to 2048.
        dropout_h : float
            The dropout rate on features, defaults to 0.1.
        dropout_a : float
            The dropout rate on attention values, defaults to 0.1.
        """
        super().__init__()
        self.d_feat = d_feat
        self.h = h
        self.d_qkv = d_qkv
        self.d_ffn = d_ffn
        self.proj_q = nn.Linear(d_feat, h * d_qkv, bias=False)
        self.proj_k = nn.Linear(d_feat, h * d_qkv, bias=False)
        self.proj_v = nn.Linear(d_feat, h * d_qkv, bias=False)
        self.proj_o = nn.Linear(h * d_qkv, d_feat, bias=False)
        self.norm_in = nn.LayerNorm(self.d_feat)
        self.norm_inter = nn.LayerNorm(self.d_feat)
        self.drop_h = nn.Dropout(dropout_h)
        self.drop_att = nn.Dropout(dropout_a)
        self.ffn = FFN(d_feat, d_ffn)

    def forward(self, g, h):
        """Forward pass of the Transformer layer on a DGLGraph.

        Parameters
        ----------
        g : DGLGraph
            The graph to apply Transformer self-attention on.
        h : Tensor
            Input node features.
        """
        q = self.proj_q(h).view(-1, self.h, self.d_qkv)  # (N, h, d_qkv)
        k = self.proj_k(h).view(-1, self.h, self.d_qkv)  # (N, h, d_qkv)
        v = self.proj_v(h).view(-1, self.h, self.d_qkv)  # (N, h, d_qkv)
        e = ops.u_dot_v(g, k, q)  # (E, h, 1): dot product of key (source) and query (destination)
        # Notice: we can also implement relative positional encoding efficiently with
        # operators such as u_dot_e and e_dot_v (see the sketch after this class).
        a = self.drop_att(ops.edge_softmax(g, e / np.sqrt(self.d_qkv)))  # (E, h, 1): attention score
        wv = ops.u_mul_e_sum(g, v, a)  # (N, h, d_qkv): weighted sum of values by attention score
        o = self.drop_h(self.proj_o(wv.view(-1, self.h * self.d_qkv)))  # (N, d_feat): output projection
        h = self.norm_in(h + o)
        h = self.norm_inter(h + self.ffn(h))
        return h
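

# --- Not part of the original gist: a minimal sketch of the relative positional
# encoding mentioned in the comment above. The subclass name `RelPosTransLayer`
# and the per-edge `rel_pos` feature are hypothetical; only the DGL builtins
# already referenced in this file (u_dot_v, e_dot_v, edge_softmax, u_mul_e_sum)
# are assumed.
class RelPosTransLayer(TransLayer):
    def forward(self, g, h, rel_pos):
        """`rel_pos` is a per-edge relative-position key embedding of shape
        (E, h, d_qkv), e.g. an nn.Embedding lookup on clipped pairwise distances."""
        q = self.proj_q(h).view(-1, self.h, self.d_qkv)
        k = self.proj_k(h).view(-1, self.h, self.d_qkv)
        v = self.proj_v(h).view(-1, self.h, self.d_qkv)
        # content term (key . query) plus positional term (rel_pos . query),
        # both computed edge-wise with DGL's builtin operators
        e = ops.u_dot_v(g, k, q) + ops.e_dot_v(g, rel_pos, q)  # (E, h, 1)
        a = self.drop_att(ops.edge_softmax(g, e / np.sqrt(self.d_qkv)))
        wv = ops.u_mul_e_sum(g, v, a)
        o = self.drop_h(self.proj_o(wv.view(-1, self.h * self.d_qkv)))
        h = self.norm_in(h + o)
        return self.norm_inter(h + self.ffn(h))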


if __name__ == "__main__":
    layer = TransLayer()
    g = dgl.rand_graph(30, 100)  # 30 nodes, 100 random edges (the attention pattern)
    h = th.rand(30, 512)
    # g = g.to(0)
    # h = h.to(0)
    print(layer(g, h))
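
The demo above uses a random graph, but the "sparse" in the gist title is exactly the freedom to choose the attention pattern: the layer attends along whatever edges the DGLGraph contains. As a hypothetical usage sketch (the `sliding_window_graph` helper and the window size are illustrations, not part of the gist), a local causal attention pattern over a length-n sequence can be built as an ordinary graph and passed straight to `TransLayer`:

def sliding_window_graph(n, window=4):
    # add edge i -> j for every i in [j - window, j]; destinations attend to sources,
    # since edge_softmax and u_mul_e_sum aggregate over a node's incoming edges
    src, dst = [], []
    for j in range(n):
        for i in range(max(0, j - window), j + 1):
            src.append(i)
            dst.append(j)
    return dgl.graph((th.tensor(src), th.tensor(dst)), num_nodes=n)

g = sliding_window_graph(30)
h = th.rand(30, 512)
print(TransLayer()(g, h))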