@yzh119
Created December 3, 2020 09:03
Efficient Sparse Transformer implementation with DGL's builtin operators
import dgl
import dgl.ops as ops
import numpy as np
import torch as th
import torch.nn as nn


class FFN(nn.Module):
    def __init__(self, d_feat, d_ffn, dropout=0.1):
        super().__init__()
        self.linear_0 = nn.Linear(d_feat, d_ffn)
        self.linear_1 = nn.Linear(d_ffn, d_feat)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.linear_1(self.dropout(th.relu(self.linear_0(x))))


class TransLayer(nn.Module):
    def __init__(self,
                 d_feat=512,
                 h=8,
                 d_qkv=64,
                 d_ffn=2048,
                 dropout_h=0.1,
                 dropout_a=0.1):
        """
        Parameters
        ----------
        d_feat : int
            The feature dimension, defaults to 512.
        h : int
            The number of attention heads, defaults to 8.
        d_qkv : int
            The query/key/value dimension per head, defaults to 64.
        d_ffn : int
            The inner feature dimension of the FFN layer, defaults to 2048.
        dropout_h : float
            The dropout rate on features, defaults to 0.1.
        dropout_a : float
            The dropout rate on attention scores, defaults to 0.1.
        """
        super().__init__()
        self.d_feat = d_feat
        self.h = h
        self.d_qkv = d_qkv
        self.d_ffn = d_ffn
        self.proj_q = nn.Linear(d_feat, h * d_qkv, bias=False)
        self.proj_k = nn.Linear(d_feat, h * d_qkv, bias=False)
        self.proj_v = nn.Linear(d_feat, h * d_qkv, bias=False)
        self.proj_o = nn.Linear(h * d_qkv, d_feat, bias=False)
        self.norm_in = nn.LayerNorm(self.d_feat)
        self.norm_inter = nn.LayerNorm(self.d_feat)
        self.drop_h = nn.Dropout(dropout_h)
        self.drop_att = nn.Dropout(dropout_a)
        self.ffn = FFN(d_feat, d_ffn)

    def forward(self, g, h):
        """Forward function of the Transformer layer on a DGLGraph.

        Parameters
        ----------
        g : DGLGraph
            The graph to apply Transformer self-attention on.
        h : Tensor
            Input node features.
        """
        q = self.proj_q(h).view(-1, self.h, self.d_qkv)  # (N, h, d_qkv)
        k = self.proj_k(h).view(-1, self.h, self.d_qkv)  # (N, h, d_qkv)
        v = self.proj_v(h).view(-1, self.h, self.d_qkv)  # (N, h, d_qkv)
        e = ops.u_dot_v(g, k, q)  # (E, h, 1): per-edge dot product of key (source) and query (destination)
        # Note: relative positional encodings can also be implemented efficiently
        # with operators such as u_dot_e and e_dot_v (see the sketch after the demo
        # at the bottom of this gist).
        a = self.drop_att(ops.edge_softmax(g, e / np.sqrt(self.d_qkv)))  # (E, h, 1): attention score
        wv = ops.u_mul_e_sum(g, v, a)  # (N, h, d_qkv): weighted sum of values by attention scores
        o = self.drop_h(self.proj_o(wv.view(-1, self.h * self.d_qkv)))  # (N, d_feat): output
        h = self.norm_in(h + o)
        h = self.norm_inter(h + self.ffn(h))
        return h


if __name__ == "__main__":
    layer = TransLayer()
    g = dgl.rand_graph(30, 100)  # a random graph with 30 nodes and 100 edges
    h = th.rand(30, 512)         # random node features with d_feat = 512
    # Uncomment the following two lines to run on GPU 0:
    # g = g.to(0)
    # h = h.to(0)
    print(layer(g, h))
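
The comment in TransLayer.forward mentions that relative positional encodings can also be implemented with operators such as u_dot_e and e_dot_v. The following is a minimal sketch of that idea, assuming node IDs correspond to consecutive sequence positions; the names relative_bias, rel_emb and max_dist are illustrative and not part of the gist.

# Hypothetical sketch (not part of the original gist): a per-edge relative
# positional bias computed with e_dot_v. rel_emb is assumed to be a learnable
# table of shape (2 * max_dist + 1, h, d_qkv).
def relative_bias(g, q, rel_emb, max_dist):
    src, dst = g.edges()
    offset = (dst - src).clamp(-max_dist, max_dist) + max_dist  # (E,): bucketed relative offset per edge
    r = rel_emb[offset]                                         # (E, h, d_qkv): per-edge relative embedding
    return ops.e_dot_v(g, r, q)                                 # (E, h, 1): dot product with the destination query

# Inside TransLayer.forward, this bias would be added to the logits before
# edge_softmax, e.g. e = e + relative_bias(g, q, self.rel_emb, max_dist=16).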
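
The "sparse" in the title comes from the graph itself: the attention pattern is whatever set of edges the DGLGraph contains, with source nodes acting as key/value positions and destination nodes as query positions. Below is a minimal sketch of a sliding-window pattern; seq_len and window are illustrative names and values, not part of the gist, and the snippet reuses the layer instance from the demo above.

# Hypothetical sketch (not part of the original gist): a sliding-window
# attention pattern expressed as a DGLGraph, where each query position i
# attends only to positions within `window` steps of i.
seq_len, window = 30, 4
src, dst = [], []
for i in range(seq_len):                   # i: query position (destination node)
    for j in range(max(0, i - window), min(seq_len, i + window + 1)):
        src.append(j)                      # j: key/value position (source node)
        dst.append(i)
g_local = dgl.graph((th.tensor(src), th.tensor(dst)), num_nodes=seq_len)
h_local = th.rand(seq_len, 512)
print(layer(g_local, h_local))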