# simple_transformer_rel_pos_embd_model
import math
import torch
from torch import nn
import torch.nn.functional as F
class SimpleTransformerClassifierWithRelPosAttnEinsum(nn.Module):
"""
Transformer-encoder-based classifier for sentiment analysis
"""
def __init__(self, vocab_size, embd_dim, num_classes, ff_dims=1208, num_heads=2, num_layers=4,
input_dropout=0.28,
mha_dropout=0.18,
ff_dropout=0.14,
rel_attn_dropout=0.11, rel_attn_combined_dropout=0.41,
padding_idx=None):
super(SimpleTransformerClassifierWithRelPosAttnEinsum, self).__init__()
self.padding_idx = padding_idx
self.embd_layer = nn.Embedding(vocab_size, embd_dim, padding_idx)
self.num_classes = num_classes
self.transformer_encoder = TransformerEncoder(self.embd_layer, embd_dim, feedforward_dims=ff_dims,
num_heads=num_heads, num_layers=num_layers,
input_dropout=input_dropout,
mha_dropout=mha_dropout, ff_dropout=ff_dropout,
rel_attn_dropout=rel_attn_dropout,
rel_attn_combined_dropout=rel_attn_combined_dropout,
max_seq_len=1024)
self.additive_attention = AdditiveAttention(embd_dim)
self.pred_layer = nn.Sequential(
nn.Linear(embd_dim, embd_dim),
nn.LeakyReLU(),
nn.BatchNorm1d(embd_dim),
nn.Linear(embd_dim, self.num_classes)
)
def forward(self, input):
if self.padding_idx is not None:
mask = (input != self.padding_idx)
else:
mask = torch.ones_like(input, dtype=torch.bool)  # no padding index, so every position is valid
x = self.transformer_encoder(input, mask)
# masked mean over valid (non-padded) positions only
context = (x * mask.unsqueeze(-1)).sum(dim=1) / mask.sum(dim=1, keepdim=True)
x = self.additive_attention(x, context, mask)
x = self.pred_layer(x)
return x
class TransformerEncoder(nn.Module):
def __init__(self, embedding, embd_dims, feedforward_dims,
num_heads, num_layers, input_dropout,
mha_dropout, ff_dropout, rel_attn_dropout,
rel_attn_combined_dropout,
max_seq_len=1024):
super(TransformerEncoder, self).__init__()
self.embedding_layer = embedding
self.hidden_dim = embd_dims
self.feedforward_dims = feedforward_dims
self.num_heads = num_heads
self.num_layers = num_layers
self.positional_encoding = PositionalEncoding(d_model=self.hidden_dim, dropout=input_dropout, max_len=5000,
batch_first=True)
self.dropout_layer = nn.Dropout(p=input_dropout)
self.encoder_blocks = nn.ModuleList(
[
EncoderBlock(self.hidden_dim, self.feedforward_dims, self.num_heads, mha_dropout, ff_dropout,
rel_attn_dropout,
rel_attn_combined_dropout, max_seq_len)
for _ in range(self.num_layers)
]
)
def forward(self, input, padding_mask=None):
x = (self.embedding_layer(input) * (self.hidden_dim ** (0.5)))
x = self.positional_encoding(x)
x = self.dropout_layer(x)
for encoder_block in self.encoder_blocks:
x = encoder_block(x, padding_mask)
return x
class EncoderBlock(nn.Module):
"""
Norm / Residual Connections adapted from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/transformer.py#L532
"""
def __init__(self, hidden_dim, feedforward_dim, num_heads, mha_dropout, ff_dropout, rel_attn_dropout,
rel_attn_combined_dropout, max_seq_len):
super(EncoderBlock, self).__init__()
self.norm_first = True
self.self_mha_rel = RelativeGlobalMultiHeadAttention(hidden_dim, num_heads, rel_attn_dropout,
rel_attn_combined_dropout, max_seq_len)
self.ff_layer = nn.Sequential(
nn.Linear(hidden_dim, feedforward_dim),
nn.ReLU(),
nn.Linear(feedforward_dim, hidden_dim)
)
self.dropout_mha_rel_layer = nn.Dropout(mha_dropout)
self.dropout_ff_layer = nn.Dropout(ff_dropout)
self.layer_norm1 = nn.LayerNorm(hidden_dim)
self.layer_norm2 = nn.LayerNorm(hidden_dim)
def forward(self, input, padding_mask):
x = input
if self.norm_first:
x = x + self._sa_block(self.layer_norm1(x), padding_mask)
x = x + self._ff_block(self.layer_norm2(x))
else:
x = self.layer_norm1(x + self._sa_block(x, padding_mask))
x = self.layer_norm2(x + self._ff_block(x))
return x
# self-attention block
def _sa_block(self, x, padding_mask):
x = self.self_mha_rel(x, padding_mask)
return self.dropout_mha_rel_layer(x)
# feed forward block
def _ff_block(self, x):
x = self.ff_layer(x)
return self.dropout_ff_layer(x)
class RelativeGlobalMultiHeadAttention(nn.Module):
"""
Adapted from https://github.com/chathasphere/pno-ai/blob/master/model/attention.py
"""
def __init__(self, hidden_dim, num_heads, rel_attn_dropout,
rel_attn_combined_dropout, max_seq_len):
super(RelativeGlobalMultiHeadAttention, self).__init__()
if hidden_dim % num_heads != 0:
raise ValueError("Model dimensions must be divisible by the number of heads.")
self.max_seq_len = max_seq_len
self.hidden_dim = hidden_dim
self.num_heads = num_heads
self.dims_per_head = hidden_dim // num_heads
self.relative_pos_use_query_key = True
self.query_layer = nn.Linear(self.hidden_dim, self.hidden_dim)
self.key_layer = nn.Linear(self.hidden_dim, self.hidden_dim)
self.value_layer = nn.Linear(self.hidden_dim, self.hidden_dim)
# input is the (shifted) relative distance between two token positions; the position
# vocabulary has (2 * max_seq_len) - 1 entries, covering distances -(max_seq_len - 1) .. (max_seq_len - 1)
# see the illustration in forward() below
self.distance_embedding = nn.Embedding((2 * max_seq_len) - 1,
embedding_dim=self.dims_per_head)  # ((2 * T) - 1, s)
self.combined_projection = nn.Linear(self.dims_per_head * self.num_heads, hidden_dim)
self.attn_dropout = nn.Dropout(rel_attn_dropout)
self.combined_projection_dropout = nn.Dropout(rel_attn_combined_dropout)
def forward(self, input, mask):
batch_size, timesteps, hidden_dim = input.shape
attn_query_slice = self.transpose_for_scores(self.query_layer(input)) # (B,H,T,s)
attn_key_slice = self.transpose_for_scores(self.key_layer(input)) # (B,H,T,s)
attn_value_slice = self.transpose_for_scores(self.value_layer(input)) # (B,H,T,s)
query_length, key_length = attn_query_slice.shape[2], attn_key_slice.shape[2]  # both equal T for self-attention
attention_scores = torch.matmul(attn_query_slice, attn_key_slice.transpose(-2, -1)) / (
self.dims_per_head ** (0.25)) # ( B, H, T, T)
"""
Relative Positional Embeddings adapted from huggingface -> Bert transformer
https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py#L343
"""
position_ids_l = torch.arange(query_length, dtype=torch.long, device=input.device).view(-1,
1) # row wise vector 2D vector(T, 1)
position_ids_r = torch.arange(key_length, dtype=torch.long, device=input.device).view(1,
-1) # column wise 2D vector (1, T)
distance = position_ids_l - position_ids_r # matrix of (T,T)
"""
Example of l -> max_seq_len = 8, (T, 1)
[[0],
[1],
[2],
[3],
[4],
[5],
[6],
[7]]
Example of r ->
[[0, 1, 2, 3, 4, 5, 6, 7]] (1, T)
distance -> (T, T)
[[ 0, -1, -2, -3, -4, -5, -6, -7],
[ 1, 0, -1, -2, -3, -4, -5, -6],
[ 2, 1, 0, -1, -2, -3, -4, -5],
[ 3, 2, 1, 0, -1, -2, -3, -4],
[ 4, 3, 2, 1, 0, -1, -2, -3],
[ 5, 4, 3, 2, 1, 0, -1, -2],
[ 6, 5, 4, 3, 2, 1, 0, -1],
[ 7, 6, 5, 4, 3, 2, 1, 0]]
Adding (max_seq_len - 1) to the distance shifts every entry into the non-negative range, so each
value can be used directly as an index into the distance embedding table. Each row corresponds to
the current (query) position, and each index on that row gives the shifted distance between the
current token and the indexed (key) token.
For example, in row 0 (the first token), the entry at index 4 is 3; undoing the shift, 3 - 7 = -4,
so the two tokens are 4 positions apart.
[[ 7, 6, 5, 4, 3, 2, 1, 0],
[ 8, 7, 6, 5, 4, 3, 2, 1],
[ 9, 8, 7, 6, 5, 4, 3, 2],
[10, 9, 8, 7, 6, 5, 4, 3],
[11, 10, 9, 8, 7, 6, 5, 4],
[12, 11, 10, 9, 8, 7, 6, 5],
[13, 12, 11, 10, 9, 8, 7, 6],
[14, 13, 12, 11, 10, 9, 8, 7]]
"""
positional_embedding = self.distance_embedding(distance + (self.max_seq_len - 1)) # (T,T, s)
# relative query and key position scores
# einsum is used for the contraction optimizations built into torch.einsum (backed by opt_einsum)
# https://pytorch.org/docs/stable/generated/torch.einsum.html
# https://optimized-einsum.readthedocs.io/en/stable/
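# A sketch of what the first contraction computes, without einsum (shape names as annotated above):
#   "bhld,lrd->bhlr" is score[b, h, l, r] = sum_d query[b, h, l, d] * positional_embedding[l, r, d],
#   i.e. roughly (attn_query_slice.unsqueeze(-2) * positional_embedding[None, None]).sum(-1);
#   einsum expresses this directly without materialising the (B, H, T, T, s) broadcast intermediate.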
if self.relative_pos_use_query_key:
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", attn_query_slice,
positional_embedding) # (B,H,T,s) @ (T,T,s) -> (B,H,T,T)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", attn_key_slice,
positional_embedding) # (B,H,T,s) @ (T,T,s) -> (B,H,T,T)
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key # (B, H, T, T)
else:
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", attn_query_slice,
positional_embedding) # (B,H,T,s) @ (T,T,s) -> (B,H,T,T)
attention_scores = attention_scores + relative_position_scores_query # (B, H, T, T)
if mask is not None:
mask_value = -1e9
assert mask.shape[-1] == attention_scores.shape[-1], 'mask length does not match the attention sequence length'
mask = mask[:, None, :] * mask[:, :, None] # (B,1,T) * (B,T, 1) -> (B,T,T)
mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1) # (B,T,T) -> (B,H, T, T)
attention_scores.masked_fill_(~mask, mask_value)
attn_prob_dist = F.softmax(attention_scores, dim=-1)
attn_prob_dist = self.attn_dropout(attn_prob_dist)
x = torch.matmul(attn_prob_dist,
attn_value_slice)  # (B, H, T, T) @ (B, H, T, s) -> (B, H, T, s), batched matrix multiply
x = x.transpose(1, 2).reshape(batch_size, timesteps, hidden_dim)  # (B, T, H, s) -> (B, T, D)
x = self.combined_projection(x)
x = self.combined_projection_dropout(x)
return x
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_heads, self.dims_per_head) # (B,T, H, s)
x = x.view(new_x_shape) # (B,T, H, s)
return x.permute(0, 2, 1, 3) # (B,H,T,s)
class AdditiveAttention(nn.Module):
"""
Adapted from https://github.com/EdwardRaff/Inside-Deep-Learning/blob/bc4dccf13a4711fa681a169ff1ae3a4f07c6ae69/idlmam.py#L389
"""
def __init__(self, embd_dims):
super(AdditiveAttention, self).__init__()
self.embd_dims = embd_dims
self.score_layer = nn.Sequential(
nn.Linear(2 * self.embd_dims, self.embd_dims),
nn.Tanh(),
nn.Linear(embd_dims, 1)
)
def forward(self, attn_states, avg_context, mask=None):
batch_size = attn_states.size(0)
time_steps = attn_states.size(1)
dims = attn_states.size(2)
avg_context = avg_context.unsqueeze(1).expand(-1, time_steps, -1)  # broadcast the context vector to every timestep
merged_attn_context = torch.cat((attn_states, avg_context), dim=2)
scores = self.score_layer(merged_attn_context)
if mask is not None:
scores[~mask] = float(-10000)
weights = F.softmax(scores, dim=1)
final_context = (attn_states * weights).sum(dim=1)
return final_context.view(batch_size, dims)
class PositionalEncoding(nn.Module):
"""
Adapted from https://github.com/pytorch/examples/blob/0c1654d6913f77f09c0505fb284d977d89c17c1a/word_language_model/model.py#L63
"""
r"""Inject some information about the relative or absolute position of the tokens
in the sequence. The positional encodings have the same dimension as
the embeddings, so that the two can be summed. Here, we use sine and cosine
functions of different frequencies.
.. math::
\text{PosEncoder}(pos, 2i) = \sin(pos / 10000^{2i / d_{model}})
\text{PosEncoder}(pos, 2i+1) = \cos(pos / 10000^{2i / d_{model}})
where pos is the word position and i is the embedding index.
Args:
d_model: the embed dim (required).
dropout: the dropout value (default=0.1).
max_len: the max. length of the incoming sequence (default=5000).
Examples:
>>> pos_encoder = PositionalEncoding(d_model)
"""
def __init__(self, d_model, dropout=0.1, max_len=5000, batch_first=False):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
self.batch_first = batch_first
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)
self.register_buffer('pe', pe)
def forward(self, x):
r"""Inputs of forward function
Args:
x: the sequence fed to the positional encoder model (required).
Shape:
x: [sequence length, batch size, embed dim]
output: [sequence length, batch size, embed dim]
Examples:
>>> output = pos_encoder(x)
"""
if self.batch_first:
x = x.permute(1, 0, 2)
x = x + self.pe[:x.size(0), :]
x = self.dropout(x)
if self.batch_first:
x = x.permute(1, 0, 2)
return x
def cur_device(tensor=None):
if tensor is None:
return 'cuda' if torch.cuda.is_available() else 'cpu'
return 'cuda' if tensor.is_cuda else 'cpu'
def initialize_weights(m):
if hasattr(m, 'weight') and m.weight.dim() > 1:
nn.init.xavier_uniform_(m.weight.data)
# nn.init.kaiming_uniform_(m.weight.data)
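if __name__ == "__main__":
    # Minimal smoke test (a sketch): the vocabulary size, class count, sequence length and
    # batch below are illustrative assumptions, not values from the original gist.
    device = cur_device()
    model = SimpleTransformerClassifierWithRelPosAttnEinsum(
        vocab_size=10000, embd_dim=128, num_classes=2, padding_idx=0
    ).to(device)
    model.apply(initialize_weights)
    dummy_input = torch.randint(1, 10000, (4, 32), device=device)  # (B=4, T=32) token ids
    dummy_input[:, -5:] = 0  # pad the tail of every sequence with padding_idx=0
    model.eval()  # use BatchNorm1d running statistics for a single forward pass
    with torch.no_grad():
        logits = model(dummy_input)
    print(logits.shape)  # expected: torch.Size([4, 2])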