@hengck23
Last active February 25, 2022
transformer fast decoder
from common import *

# `common` is assumed to provide these; they are also imported explicitly
# so the dependencies of this file are visible.
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# https://scale.com/blog/pytorch-improvements
# "Making Pytorch Transformer Twice as Fast on Sequence Generation"
# https://towardsdatascience.com/how-to-code-the-transformer-in-pytorch-24db27c8f9ec
class FeedForward(nn.Module):
    def __init__(self, dim, ff_dim=2048, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(dim, ff_dim)
        self.linear2 = nn.Linear(ff_dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.dropout(F.relu(self.linear1(x)))
        x = self.linear2(x)
        return x
# layer normalization
class Norm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(dim))
        self.bias = nn.Parameter(torch.zeros(dim))
        self.eps = eps

    def forward(self, x):
        z = (x - x.mean(dim=-1, keepdim=True)) / (x.std(dim=-1, keepdim=True) + self.eps)
        x = self.alpha * z + self.bias
        return x
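
# quick sanity check for Norm, a minimal sketch with assumed toy sizes: with
# the default alpha=1 and bias=0 every feature vector should come out with
# (near) zero mean and unit standard deviation. note Norm uses the unbiased
# std, so it is close to, but not bit-identical with, nn.LayerNorm.
def run_check_norm():
    norm = Norm(dim=16)
    x = torch.rand(2, 10, 16)
    z = norm(x)
    print(z.mean(dim=-1).abs().max())  # ~0
    print(z.std(dim=-1).mean())        # ~1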
class MultiHeadAttention(nn.Module):
    def __init__(self, dim, num_head, dropout=0.1):
        super().__init__()
        self.dim = dim
        self.d_k = dim // num_head
        self.num_head = num_head
        self.dropout = dropout

        self.q = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.out = nn.Linear(dim, dim)

    def attention(self, q, k, v, mask):
        # score: batch_size x num_head x Lq x Lk
        score = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        score = torch.clamp_min(score, -6e3)
        if mask is not None:
            mask = mask.unsqueeze(1)
            # use a large negative value instead of -inf so the softmax stays
            # finite under fp16 (max magnitude ~65504)
            # https://github.com/NVIDIA/apex/issues/93
            # "How to use fp16 training with masked operations"
            score = score.masked_fill(mask == 0, -6e4)

        score = F.softmax(score, dim=-1)
        if self.dropout > 0:
            score = F.dropout(score, self.dropout, training=self.training)
        value = torch.matmul(score, v)
        return value

    def forward(self, q, k, v, mask=None):
        batch_size, T, dim = q.shape

        # perform linear operation and split into num_head heads
        k = self.k(k).reshape(batch_size, -1, self.num_head, self.d_k)
        q = self.q(q).reshape(batch_size, -1, self.num_head, self.d_k)
        v = self.v(v).reshape(batch_size, -1, self.num_head, self.d_k)

        # transpose to get dimensions batch_size x num_head x T x d_k
        k = k.transpose(1, 2)
        q = q.transpose(1, 2)
        v = v.transpose(1, 2)

        # scaled dot-product attention per head
        value = self.attention(q, k, v, mask)

        # concatenate heads and put through final linear layer
        value = value.transpose(1, 2).contiguous().reshape(batch_size, -1, self.dim)
        value = self.out(value)
        return value
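
# minimal shape check for MultiHeadAttention, a sketch with assumed toy sizes:
# self-attention over 2 sequences of length 10 with dim=8 and 2 heads should
# return a tensor with the same shape as the query input.
def run_check_multi_head_attention():
    attn = MultiHeadAttention(dim=8, num_head=2)
    attn.eval()  # turn off attention dropout for a deterministic check
    x = torch.rand(2, 10, 8)
    mask = torch.ones(2, 10, 10)
    y = attn(x, x, x, mask)
    print(y.shape)  # torch.Size([2, 10, 8])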
#---
class TransformerEncodeLayer(nn.Module):
    def __init__(self, dim, ff_dim, num_head, dropout=0.1):
        super().__init__()
        self.norm1 = Norm(dim)
        self.norm2 = Norm(dim)
        self.attn = MultiHeadAttention(dim, num_head, dropout=dropout)
        self.ff = FeedForward(dim, ff_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, x_mask):
        x1 = self.attn(x, x, x, x_mask)  # self-attention
        x1 = x + self.dropout1(x1)
        x = self.norm1(x1)

        x2 = self.ff(x)
        x2 = x + self.dropout2(x2)
        x = self.norm2(x2)
        return x
class TransformerEncode(nn.Module):
    def __init__(self, dim, ff_dim, num_head, num_layer):
        super().__init__()
        self.num_layer = num_layer
        self.layer = nn.ModuleList([
            TransformerEncodeLayer(dim, ff_dim, num_head) for i in range(num_layer)
        ])
        self.norm = Norm(dim)

    def forward(self, x, x_mask):
        for i in range(self.num_layer):
            x = self.layer[i](x, x_mask)
        return x
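
# usage sketch for TransformerEncode with an assumed padding-mask layout:
# x is batch_size x T x dim and x_mask is batch_size x T x T, 1 where attention
# is allowed and 0 at padded positions (MultiHeadAttention broadcasts it over
# heads with unsqueeze(1)).
def run_check_encode():
    encoder = TransformerEncode(dim=8, ff_dim=16, num_head=2, num_layer=2)
    encoder.eval()
    x = torch.rand(2, 10, 8)
    valid = torch.tensor([[1] * 10, [1] * 7 + [0] * 3])  # second sequence has 3 padded steps
    x_mask = valid.unsqueeze(1) * valid.unsqueeze(2)     # batch_size x T x T
    y = encoder(x, x_mask)
    print(y.shape)  # torch.Size([2, 10, 8])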
#---
class TransformerDecodeLayer(nn.Module):
    def __init__(self, dim, ff_dim, num_head, dropout=0.1):
        super().__init__()
        self.norm1 = Norm(dim)
        self.norm2 = Norm(dim)
        self.norm3 = Norm(dim)
        self.attn1 = MultiHeadAttention(dim, num_head, dropout=dropout)
        self.attn2 = MultiHeadAttention(dim, num_head, dropout=dropout)
        self.ff = FeedForward(dim, ff_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, x, mem, x_mask, mem_mask):
        x1 = self.attn1(x, x, x, x_mask)  # self-attention
        x1 = x + self.dropout1(x1)
        x = self.norm1(x1)

        if mem is not None:
            x2 = self.attn2(x, mem, mem, mem_mask)  # cross-attention over encoder output
            x2 = x + self.dropout2(x2)
            x = self.norm2(x2)

        x3 = self.ff(x)
        x3 = x + self.dropout3(x3)
        x = self.norm3(x3)
        return x

    # fast autoregressive decoding: compute only the last time step; it attends
    # over the whole prefix x, so no causal mask is needed.
    def forward_last_one(self, x, mem, mem_mask):
        x_one = x[:, [-1]]

        x1 = self.attn1(x_one, x, x)  # self-attention of the last token over the whole prefix
        x_one = x_one + x1
        x_one = self.norm1(x_one)

        if mem is not None:
            x2 = self.attn2(x_one, mem, mem, mem_mask)  # cross-attention over encoder output
            x_one = x_one + x2
            x_one = self.norm2(x_one)

        x3 = self.ff(x_one)
        x_one = x_one + x3
        x_one = self.norm3(x_one)
        return x_one
# ------------------------------------------------------
# https://kazemnejad.com/blog/transformer_architecture_positional_encoding/
# https://stackoverflow.com/questions/46452020/sinusoidal-embedding-attention-is-all-you-need
class PositionEncode1D(nn.Module):
    def __init__(self, dim, max_length):
        super().__init__()
        assert (dim % 2 == 0)
        self.max_length = max_length

        d = torch.exp(torch.arange(0., dim, 2) * (-math.log(10000.0) / dim))
        position = torch.arange(0., max_length).unsqueeze(1)
        pos = torch.zeros(1, max_length, dim)
        pos[0, :, 0::2] = torch.sin(position * d)
        pos[0, :, 1::2] = torch.cos(position * d)
        self.register_buffer('pos', pos)

    def forward(self, x):
        batch_size, T, dim = x.shape
        x = x + self.pos[:, :T]
        return x
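
# usage sketch for PositionEncode1D with assumed toy sizes: it adds a fixed
# sinusoidal code to a batch_size x T x dim tensor (e.g. token embeddings
# before the decoder), leaving the shape unchanged.
def run_check_position_encode_1d():
    pos = PositionEncode1D(dim=8, max_length=100)
    x = torch.zeros(2, 10, 8)
    y = pos(x)
    print(y.shape)       # torch.Size([2, 10, 8])
    print(y[0, :3, :4])  # first few positions of the sinusoid pattern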
# https://github.com/wzlxjtu/PositionalEncoding2D/blob/master/positionalembedding2d.py
class PositionEncode2D(nn.Module):
    def __init__(self, dim, width, height):
        super().__init__()
        assert (dim % 4 == 0)
        self.width = width
        self.height = height

        dim = dim // 2
        d = torch.exp(torch.arange(0., dim, 2) * -(math.log(10000.0) / dim))
        position_w = torch.arange(0., width).unsqueeze(1)
        position_h = torch.arange(0., height).unsqueeze(1)

        pos = torch.zeros(1, dim * 2, height, width)
        pos[0, 0:dim:2, :, :] = torch.sin(position_w * d).transpose(0, 1).unsqueeze(1).repeat(1, 1, height, 1)
        pos[0, 1:dim:2, :, :] = torch.cos(position_w * d).transpose(0, 1).unsqueeze(1).repeat(1, 1, height, 1)
        pos[0, dim + 0::2, :, :] = torch.sin(position_h * d).transpose(0, 1).unsqueeze(2).repeat(1, 1, 1, width)
        pos[0, dim + 1::2, :, :] = torch.cos(position_h * d).transpose(0, 1).unsqueeze(2).repeat(1, 1, 1, width)
        self.register_buffer('pos', pos)

    def forward(self, x):
        batch_size, C, H, W = x.shape
        x = x + self.pos[:, :, :H, :W]
        return x

# pos = PositionEncode(dim=128)
# relative_time = torch.rand(10,5,1)
# pos(relative_time)
# exit(0)
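
# usage sketch for PositionEncode2D with assumed toy sizes: it adds a fixed
# 2-d sinusoidal code to a batch_size x C x H x W feature map (e.g. CNN
# features feeding the encoder), leaving the shape unchanged.
def run_check_position_encode_2d():
    pos = PositionEncode2D(dim=16, width=20, height=12)
    x = torch.zeros(2, 16, 12, 20)
    y = pos(x)
    print(y.shape)  # torch.Size([2, 16, 12, 20])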
# ------------------------------------
'''
mask: np.triu(np.ones((1, size, size)), k=1) for size=10. its complement
(where the triu entries are 0) is used as the causal mask, i.e. 1/True marks
positions that may be attended to.

array([[[0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]], dtype=uint8)
'''
# def triangle_mask(size):
#     mask = np.triu(np.ones((1, size, size)), k=1).astype('uint8')
#     mask = torch.autograd.Variable(torch.from_numpy(mask) == 0)
#     return mask
# #triangle_mask(10)
#
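
# a runnable sketch of the same causal mask (the helper name here is
# illustrative): 1 where position j may be attended to from position i
# (j <= i), 0 elsewhere, so masked_fill(mask == 0, ...) in MultiHeadAttention
# blocks attention to future positions.
def make_triangle_mask(size):
    mask = torch.tril(torch.ones(1, size, size)).to(torch.uint8)
    return mask
# make_triangle_mask(4)
# tensor([[[1, 0, 0, 0],
#          [1, 1, 0, 0],
#          [1, 1, 1, 0],
#          [1, 1, 1, 1]]], dtype=torch.uint8)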
# https://github.com/alexmt-scale/causal-transformer-decoder/blob/master/causal_transformer_decoder/model.py
class TransformerDecode(nn.Module):
    def __init__(self, dim, ff_dim, num_head, num_layer):
        super().__init__()
        self.num_layer = num_layer
        self.layer = nn.ModuleList([
            TransformerDecodeLayer(dim, ff_dim, num_head) for i in range(num_layer)
        ])
        self.norm = Norm(dim)

    def forward(self, x, mem, x_mask=None, mem_mask=None):
        for i in range(self.num_layer):
            x = self.layer[i](x, mem, x_mask, mem_mask)
        return x

    def forward_last_one(self, x, mem, mem_mask=None, cache=None):
        # x holds the full decoded prefix; each layer computes only the last
        # time step, while cache holds that layer's outputs for all earlier
        # steps so the full sequence can be re-assembled as input to the next layer.
        xx = []
        for i in range(self.num_layer):
            x = self.layer[i].forward_last_one(x, mem, mem_mask)
            xx.append(x)
            if cache is not None:
                x = torch.cat([cache[i], x], dim=1)

        if cache is not None:
            new_cache = torch.cat([cache, torch.stack(xx, dim=0)], dim=2)
        else:
            new_cache = torch.stack(xx, dim=0)  # num_layer, batch_size, length, dim
        return x, new_cache
# check ################################################################
# https://github.com/alexmt-scale/causal-transformer-decoder/blob/master/tests/test_consistency.py
def run_check_fast_decode():
    batch_size = 2
    length = 6
    dim = 4
    num_head = 2
    ff_dim = dim * num_head
    num_layer = 1

    decoder = TransformerDecode(dim, ff_dim, num_head, num_layer)
    decoder.eval()

    # ----
    mem = torch.rand(batch_size, 5, dim)
    first_x = torch.rand(batch_size, 1, dim)

    # ---- naive decoding: re-run the full prefix with a causal mask at each step
    x1 = first_x
    for t in range(length - 1):
        mask = 1 - np.triu(np.ones((batch_size, (t + 1), (t + 1))), k=1).astype(np.uint8)
        mask = torch.from_numpy(mask)
        y = decoder(x1, mem, x_mask=mask)
        x1 = torch.cat([x1, y[:, -1:]], dim=1)
    print(x1)
    print(x1.shape)

    # ---- fast decoding: compute only the last time step, reusing the cache
    cache = None
    x2 = first_x
    for t in range(length - 1):
        y, cache = decoder.forward_last_one(x2, mem, cache=cache)
        x2 = torch.cat([x2, y[:, -1:]], dim=1)
    print(x2)
    print(x2.shape)

    # both decoding paths should produce the same sequence (up to float error)
    print(torch.eq(x1, x2))
    diff = torch.abs(x1 - x2)
    print(diff)
    print(diff.max(), diff.min())


# main #################################################################
if __name__ == '__main__':
    run_check_fast_decode()