Well-documented Transformer
import torch
import torch.nn as nn
from typing import Optional
from math import (
    pi,
    sqrt,
)
from torch import Tensor
from torch.nn.functional import softmax
__all__ = (
    'GELU',
    'LayerNorm',
    'Attention',
    'MultiHeadedAttention',
    'SublayerConnection',
    'PositionWiseFeedForward',
    'Transformer',
)
class GELU(nn.Module):
    def forward(self, x: Tensor):
        return 0.5 * x * (1 + torch.tanh(sqrt(2 / pi) * (x + 0.044715 * torch.pow(x, 3))))
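
# Note: this is the tanh approximation of GELU used in the original BERT/GPT code.
# A quick sanity check, as a sketch (assumes a recent PyTorch where `F.gelu` accepts
# `approximate='tanh'`; this comparison is not part of the original gist):
#
#     >>> import torch.nn.functional as F
#     >>> x = torch.randn(4)
#     >>> torch.allclose(GELU()(x), F.gelu(x, approximate='tanh'), atol=1e-6)
#     True
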
class LayerNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        # params
        self.dim = dim
        self.eps = eps
        # layers
        self.alpha = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))

    def forward(self, x: Tensor):
        mu = x.mean(-1, keepdim=True)
        sigma = x.std(-1, keepdim=True)
        return self.alpha * (x - mu) / (sigma + self.eps) + self.beta
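
# Note: unlike `nn.LayerNorm`, this variant adds `eps` to the standard deviation rather
# than to the variance, and `x.std()` applies Bessel's correction, so outputs differ
# slightly from the built-in layer. A minimal usage sketch (hypothetical shapes):
#
#     >>> ln = LayerNorm(dim=8)
#     >>> ln(torch.randn(2, 5, 8)).shape  # normalizes over the last dimension
#     torch.Size([2, 5, 8])
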
class Attention(nn.Module):
    def forward(self,
                Q: Tensor,
                K: Tensor,
                V: Tensor,
                mask: Optional[Tensor] = None,
                dropout: Optional[nn.Module] = None
                ):
        """
        Q: (b x ? x L x dim_Q)
        K: (b x ? x L x dim_K)
        V: (b x ? x L x dim_V)
        ?: 1 (squeezed) or h (multi-head)
        mask: (b x ? x L x L)
        dropout: nn.Module
        assuming dim_Q = dim_K
        """
        dim_Q = Q.size(-1)
        # A: (b x ? x L x L)
        A = torch.matmul(Q, K.transpose(-2, -1)) / sqrt(dim_Q)
        # apply mask (the logit value of a padding token should be minus infinity)
        if mask is not None:
            A = A.masked_fill(mask == 0, -1e9)  # tip: `mask is False` does not invoke broadcasting
        # get normalized (probability) weights through softmax (padding tokens end up with weight 0)
        # P: (b x ? x L x L)
        P = softmax(A, dim=-1)
        # apply dropout (with the given dropout module)
        if dropout is not None:
            P = dropout(P)
        # (b x ? x L x L) @ (b x ? x L x dim_V) -> (b x ? x L x dim_V)
        x = torch.matmul(P, V)
        return x, P
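
# Usage sketch for the module above (hypothetical shapes): scaled dot-product attention
# over a single "head" dimension, returning both the output and the attention map.
#
#     >>> b, L, d = 2, 7, 16
#     >>> Q = K = V = torch.randn(b, 1, L, d)
#     >>> mask = torch.ones(b, 1, L, L)
#     >>> x, P = Attention()(Q, K, V, mask=mask)
#     >>> x.shape, P.shape
#     (torch.Size([2, 1, 7, 16]), torch.Size([2, 1, 7, 7]))
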
class MultiHeadedAttention(nn.Module):
    def __init__(self,
                 num_heads: int,
                 dim_model: int,
                 dropout_prob: float = 0.1
                 ):
        """
        dim_K should be equal to dim_model / num_heads
        we assume dim_Q = dim_K = dim_V
        """
        super().__init__()
        assert dim_model % num_heads == 0
        # params
        self.dim_model = dim_model
        self.num_heads = num_heads
        self.dropout_prob = dropout_prob
        # per-head dim_K (dim_model split across heads)
        self.dim_K = dim_model // num_heads
        # layers
        self.W_Q = nn.Linear(dim_model, dim_model)
        self.W_K = nn.Linear(dim_model, dim_model)
        self.W_V = nn.Linear(dim_model, dim_model)
        self.W_M = nn.Linear(dim_model, dim_model)
        self.attention = Attention()
        self.dropout = nn.Dropout(p=dropout_prob)

    def forward(self,
                Q: Tensor,
                K: Tensor,
                V: Tensor,
                mask: Optional[Tensor] = None
                ):
        b = Q.size(0)
        # 1) Do all the linear projections in a batch from dim_model, then split into (num_heads x dim_K)
        # [process]
        # (1) linear(W): (b x L x dim_model) -> (b x L x dim_model)
        # (2) view: (b x L x dim_model) -> (b x L x num_heads x dim_K)
        # (3) transpose: (b x L x num_heads x dim_K) -> (b x num_heads x L x dim_K)
        Q = self.W_Q(Q).view(b, -1, self.num_heads, self.dim_K).transpose(1, 2)
        K = self.W_K(K).view(b, -1, self.num_heads, self.dim_K).transpose(1, 2)
        V = self.W_V(V).view(b, -1, self.num_heads, self.dim_K).transpose(1, 2)
        # 2) Apply attention to the projected vectors in the batch
        # note that attention only cares about the last two dimensions
        # x: (b x num_heads x L x dim_K)
        x, _ = self.attention(Q, K, V, mask=mask, dropout=self.dropout)
        # 3) "concat" those heads using view
        # [process]
        # (1) transpose: (b x num_heads x L x dim_K) -> (b x L x num_heads x dim_K)
        # (2) contiguous: make the memory layout contiguous again so view can be applied (no dimension change)
        # (3) view: (b x L x num_heads x dim_K) -> (b x L x dim_model)
        x = x.transpose(1, 2).contiguous().view(b, -1, self.dim_model)
        # 4) apply the final linear
        # x: (b x L x dim_model)
        x = self.W_M(x)
        return x
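
# Usage sketch (hypothetical shapes): self-attention with 4 heads over a batch of
# sequences; a (b x 1 x L x L) mask broadcasts across the head dimension.
#
#     >>> mha = MultiHeadedAttention(num_heads=4, dim_model=256)
#     >>> x = torch.randn(2, 10, 256)
#     >>> mask = torch.ones(2, 1, 10, 10)
#     >>> mha(x, x, x, mask=mask).shape
#     torch.Size([2, 10, 256])
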
class SublayerConnection(nn.Module):
    def __init__(self, dim: int = 256, dropout_prob: float = 0.1):
        super().__init__()
        # params
        self.dim = dim
        self.dropout_prob = dropout_prob
        # layers
        self.layernorm = LayerNorm(dim)
        self.dropout = nn.Dropout(p=dropout_prob)

    def forward(self, x: Tensor, sublayer: nn.Module):
        r = self.layernorm(x)
        r = sublayer(r)
        r = self.dropout(r)
        return x + r
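
# This is the pre-norm residual pattern: x + dropout(sublayer(layernorm(x))).
# Any module (or callable) mapping (b x L x dim) -> (b x L x dim) can be wrapped,
# e.g. (hypothetical usage):
#
#     >>> sub = SublayerConnection(dim=256)
#     >>> sub(torch.randn(2, 10, 256), nn.Linear(256, 256)).shape
#     torch.Size([2, 10, 256])
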
class PositionWiseFeedForward(nn.Module):
    def __init__(self,
                 dim_model: int = 256,
                 dim_ff: int = 1024,
                 dropout_prob: float = 0.1
                 ):
        super().__init__()
        # params
        self.dim_model = dim_model
        self.dim_ff = dim_ff
        self.dropout_prob = dropout_prob
        # layers
        self.W_1 = nn.Linear(dim_model, dim_ff)
        self.W_2 = nn.Linear(dim_ff, dim_model)
        self.dropout = nn.Dropout(p=dropout_prob)
        self.gelu = GELU()

    def forward(self, x: Tensor):
        x = self.W_1(x)  # (b x L x dim_model) -> (b x L x dim_ff)
        x = self.gelu(x)
        x = self.dropout(x)
        x = self.W_2(x)  # (b x L x dim_ff) -> (b x L x dim_model)
        return x
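
# Usage sketch (hypothetical shapes): the feed-forward block is applied to every
# position independently, expanding to dim_ff and projecting back to dim_model.
#
#     >>> ff = PositionWiseFeedForward(dim_model=256, dim_ff=1024)
#     >>> ff(torch.randn(2, 10, 256)).shape
#     torch.Size([2, 10, 256])
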
class Transformer(nn.Module):
    def __init__(self,
                 dim_model: int = 256,
                 num_heads: int = 4,
                 dim_ff: int = 1024,
                 dropout_prob: float = 0.1
                 ):
        super().__init__()
        # params
        self.dim_model = dim_model
        self.num_heads = num_heads
        self.dim_ff = dim_ff
        self.dropout_prob = dropout_prob
        # layers
        self.attention = MultiHeadedAttention(num_heads=num_heads, dim_model=dim_model, dropout_prob=dropout_prob)
        self.attention_sublayer = SublayerConnection(dim=dim_model, dropout_prob=dropout_prob)
        self.pwff = PositionWiseFeedForward(dim_model=dim_model, dim_ff=dim_ff, dropout_prob=dropout_prob)
        self.pwff_sublayer = SublayerConnection(dim=dim_model, dropout_prob=dropout_prob)
        self.dropout = nn.Dropout(p=dropout_prob)
    def forward(self, x: Tensor, mask: Optional[Tensor] = None):
        # the mask changes per batch, so the attention call is passed as a closure
        # (the sublayer wrapper itself also has parameters, namely its layernorm)
        # x: (b x L x dim_model)
        # mask: (b x 1 x L x L), broadcast over heads; 0 (False) marks positions to ignore
        x = self.attention_sublayer(x, lambda z: self.attention.forward(z, z, z, mask=mask))
        x = self.pwff_sublayer(x, self.pwff)
        x = self.dropout(x)
        return x
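

# A minimal end-to-end sketch (hypothetical batch size, sequence length, and padding
# layout; not part of the original gist). A full encoder would stack this block and
# add token/positional embeddings on top.
if __name__ == '__main__':
    b, L, dim_model = 2, 10, 256
    x = torch.randn(b, L, dim_model)
    # suppose the last 3 positions of every sequence are padding tokens
    valid = torch.ones(b, L)
    valid[:, -3:] = 0
    # build a (b x 1 x L x L) mask that is 0 wherever either position is padding;
    # the singleton dimension broadcasts over the attention heads
    mask = (valid.unsqueeze(1) * valid.unsqueeze(2)).unsqueeze(1)
    block = Transformer(dim_model=dim_model, num_heads=4, dim_ff=1024)
    out = block(x, mask=mask)
    print(out.shape)  # torch.Size([2, 10, 256])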