Last active
December 23, 2022 03:12
-
-
Save eleganceinsimplicity/c37eb04ea6db22d8e2d6542640437629 to your computer and use it in GitHub Desktop.
simple_transformer_rel_pos_embd_model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math | |
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
class SimpleTransformerClassifierWithRelPosAttnEinsum(nn.Module):
    """
    Transformer Encoder based Classifier for Sentiment Analysis.

    Pipeline: token embedding -> ``TransformerEncoder`` -> masked mean of the
    encoder states (used as the attention context) -> ``AdditiveAttention``
    pooling -> small MLP head that emits ``num_classes`` raw scores.

    Args:
        vocab_size: number of rows in the token embedding table.
        embd_dim: embedding width, also used as the model width throughout.
        num_classes: output width of the prediction head.
        ff_dims: hidden width of the feed-forward layer in each encoder block.
        num_heads: attention heads per encoder block.
        num_layers: number of encoder blocks.
        input_dropout: dropout probability forwarded to the encoder input stage.
        mha_dropout: dropout probability applied to each attention sub-layer output.
        ff_dropout: dropout probability applied to each feed-forward sub-layer output.
        rel_attn_dropout: dropout probability on the attention weights.
        rel_attn_combined_dropout: dropout probability after the attention output projection.
        padding_idx: token id treated as padding; ``None`` disables padding masking.
        max_seq_len: longest sequence the relative attention layers are built for
            (defaults to 1024, the value that used to be hard-coded).
    """

    def __init__(self, vocab_size, embd_dim, num_classes, ff_dims=1208, num_heads=2, num_layers=4,
                 input_dropout=0.28,
                 mha_dropout=0.18,
                 ff_dropout=0.14,
                 rel_attn_dropout=0.11, rel_attn_combined_dropout=0.41,
                 padding_idx=None,
                 max_seq_len=1024):
        super().__init__()
        self.padding_idx = padding_idx
        self.embd_layer = nn.Embedding(vocab_size, embd_dim, padding_idx)
        self.num_classes = num_classes
        # The embedding module is shared with the encoder, so it is registered
        # under both ``embd_layer`` and ``transformer_encoder.embedding_layer``.
        self.transformer_encoder = TransformerEncoder(self.embd_layer, embd_dim, feedforward_dims=ff_dims,
                                                      num_heads=num_heads, num_layers=num_layers,
                                                      input_dropout=input_dropout,
                                                      mha_dropout=mha_dropout, ff_dropout=ff_dropout,
                                                      rel_attn_dropout=rel_attn_dropout,
                                                      rel_attn_combined_dropout=rel_attn_combined_dropout,
                                                      max_seq_len=max_seq_len)
        self.additive_attention = AdditiveAttention(embd_dim)
        self.pred_layer = nn.Sequential(
            nn.Linear(embd_dim, embd_dim),
            nn.LeakyReLU(),
            nn.BatchNorm1d(embd_dim),
            nn.Linear(embd_dim, self.num_classes)
        )

    def forward(self, input):
        """
        Classify a batch of token-id sequences.

        Args:
            input: integer tensor of token ids, indexed as (batch, timestep).

        Returns:
            Tensor of shape (batch, num_classes) with the raw outputs of the
            prediction head (no softmax is applied here).
        """
        if self.padding_idx is not None:
            mask = input != self.padding_idx
        else:
            # No padding id configured: every position counts as a real token.
            mask = torch.ones_like(input, dtype=torch.bool)

        x = self.transformer_encoder(input, mask)  # (B, T, D)

        # Mean over real tokens only. Encoder states at padded positions are
        # not zero, so they must be excluded from the numerator as well as the
        # denominator. The clamp avoids 0/0 for a sequence that is all padding.
        token_weights = mask.unsqueeze(-1).to(x.dtype)  # (B, T, 1)
        token_counts = token_weights.sum(dim=1).clamp(min=1.0)  # (B, 1)
        context = (x * token_weights).sum(dim=1) / token_counts  # (B, D)

        x = self.additive_attention(x, context, mask)
        return self.pred_layer(x)
class TransformerEncoder(nn.Module):
    """
    Stack of ``EncoderBlock`` layers on top of a scaled token embedding.

    Args:
        embedding: embedding module mapping token ids to ``embd_dims``-wide vectors.
        embd_dims: model width.
        feedforward_dims: hidden width of each block's feed-forward layer.
        num_heads: attention heads per block.
        num_layers: number of encoder blocks.
        input_dropout: dropout probability for the embedded input.
        mha_dropout: dropout probability on each attention sub-layer output.
        ff_dropout: dropout probability on each feed-forward sub-layer output.
        rel_attn_dropout: dropout probability on the attention weights.
        rel_attn_combined_dropout: dropout probability after the attention output projection.
        max_seq_len: maximum sequence length handed to every encoder block.
    """

    def __init__(self, embedding, embd_dims, feedforward_dims,
                 num_heads, num_layers, input_dropout,
                 mha_dropout, ff_dropout, rel_attn_dropout,
                 rel_attn_combined_dropout,
                 max_seq_len=1024):
        super().__init__()
        self.embedding_layer = embedding
        self.hidden_dim = embd_dims
        self.feedforward_dims = feedforward_dims
        self.num_heads = num_heads
        self.num_layers = num_layers
        # NOTE(review): PositionalEncoding applies dropout(input_dropout) itself
        # and ``forward`` applies ``dropout_layer`` again right after, so the
        # input is dropped out twice. Kept as-is to preserve training behavior.
        self.positional_encoding = PositionalEncoding(d_model=self.hidden_dim, dropout=input_dropout, max_len=5000,
                                                      batch_first=True)
        self.dropout_layer = nn.Dropout(p=input_dropout)
        self.encoder_blocks = nn.ModuleList(
            [
                EncoderBlock(self.hidden_dim, self.feedforward_dims, self.num_heads, mha_dropout, ff_dropout,
                             rel_attn_dropout,
                             rel_attn_combined_dropout, max_seq_len)
                for _ in range(self.num_layers)
            ]
        )

    def forward(self, input, padding_mask=None):
        """
        Encode a batch of token ids.

        Args:
            input: token ids accepted by the embedding module, batch first.
            padding_mask: optional mask forwarded unchanged to every encoder
                block; ``None`` means no masking.

        Returns:
            The output of the last encoder block, (batch, timestep, embd_dims).
        """
        # Scale embeddings by sqrt(d_model) before adding positional encodings.
        x = self.embedding_layer(input) * (self.hidden_dim ** 0.5)
        x = self.positional_encoding(x)
        x = self.dropout_layer(x)
        for encoder_block in self.encoder_blocks:
            # Call the module, not ``.forward``, so registered hooks are run.
            x = encoder_block(x, padding_mask)
        return x
class EncoderBlock(nn.Module):
    """
    One transformer encoder layer: relative-position self-attention followed by
    a position-wise feed-forward network, each wrapped in a residual connection
    with layer normalization.

    Norm / Residual Connections adapted from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/transformer.py#L532

    Args:
        hidden_dim: model width (input and output size of the block).
        feedforward_dim: hidden width of the feed-forward network.
        num_heads: number of attention heads.
        mha_dropout: dropout probability on the attention sub-layer output.
        ff_dropout: dropout probability on the feed-forward sub-layer output.
        rel_attn_dropout: dropout probability on the attention weights.
        rel_attn_combined_dropout: dropout probability after the attention output projection.
        max_seq_len: maximum sequence length for the relative attention layer.
        norm_first: if True (default, the previous fixed behavior) apply layer
            norm before each sub-layer (pre-norm); otherwise after the residual
            sum (post-norm).
    """

    def __init__(self, hidden_dim, feedforward_dim, num_heads, mha_dropout, ff_dropout, rel_attn_dropout,
                 rel_attn_combined_dropout, max_seq_len, norm_first=True):
        super().__init__()
        self.norm_first = norm_first
        self.self_mha_rel = RelativeGlobalMultiHeadAttention(hidden_dim, num_heads, rel_attn_dropout,
                                                             rel_attn_combined_dropout, max_seq_len)
        self.ff_layer = nn.Sequential(
            nn.Linear(hidden_dim, feedforward_dim),
            nn.ReLU(),
            nn.Linear(feedforward_dim, hidden_dim)
        )
        self.dropout_mha_rel_layer = nn.Dropout(mha_dropout)
        self.dropout_ff_layer = nn.Dropout(ff_dropout)
        self.layer_norm1 = nn.LayerNorm(hidden_dim)
        self.layer_norm2 = nn.LayerNorm(hidden_dim)

    def forward(self, input, padding_mask):
        """
        Run the block.

        Args:
            input: tensor of shape (batch, timestep, hidden_dim).
            padding_mask: mask passed through to the attention layer, or ``None``.

        Returns:
            Tensor with the same shape as ``input``.
        """
        x = input
        if self.norm_first:
            x = x + self._sa_block(self.layer_norm1(x), padding_mask)
            x = x + self._ff_block(self.layer_norm2(x))
        else:
            x = self.layer_norm1(x + self._sa_block(x, padding_mask))
            x = self.layer_norm2(x + self._ff_block(x))
        return x

    # self-attention block
    def _sa_block(self, x, padding_mask):
        # Call the module, not ``.forward``, so registered hooks are run.
        # NOTE(review): the attention module already applies dropout after its
        # output projection; ``dropout_mha_rel_layer`` is a second dropout on top.
        x = self.self_mha_rel(x, padding_mask)
        return self.dropout_mha_rel_layer(x)

    # feed forward block
    def _ff_block(self, x):
        x = self.ff_layer(x)
        return self.dropout_ff_layer(x)
class RelativeGlobalMultiHeadAttention(nn.Module):
    """
    Multi-head self-attention with learned relative position embeddings.

    Adapted from https://github.com/chathasphere/pno-ai/blob/master/model/attention.py

    Relative Positional Embeddings adapted from huggingface -> Bert transformer
    https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py#L343

    Args:
        hidden_dim: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        rel_attn_dropout: dropout probability on the attention weights.
        rel_attn_combined_dropout: dropout probability after the output projection.
        max_seq_len: longest sequence length supported; the distance embedding
            table has ``2 * max_seq_len - 1`` rows.

    Raises:
        NameError: if ``hidden_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, hidden_dim, num_heads, rel_attn_dropout,
                 rel_attn_combined_dropout, max_seq_len):
        super().__init__()
        if hidden_dim % num_heads != 0:
            # NameError is an unusual choice for a bad argument, but it is kept
            # so callers that already catch it keep working.
            raise NameError("Model dimensions must be divisible by the number of heads.")
        self.max_seq_len = max_seq_len
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.dims_per_head = hidden_dim // num_heads
        # When True, both query and key are scored against the relative
        # position embeddings; when False only the query is.
        self.relative_pos_use_query_key = True
        self.query_layer = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.key_layer = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.value_layer = nn.Linear(self.hidden_dim, self.hidden_dim)
        # One embedding per signed distance in [-(max_seq_len - 1), max_seq_len - 1],
        # i.e. a "vocabulary" of 2 * max_seq_len - 1 distances.
        self.distance_embedding = nn.Embedding((2 * max_seq_len) - 1,
                                               embedding_dim=self.dims_per_head)  # ((2 * T) - 1, s)
        self.combined_projection = nn.Linear(self.dims_per_head * self.num_heads, hidden_dim)
        self.attn_dropout = nn.Dropout(rel_attn_dropout)
        self.combined_projection_dropout = nn.Dropout(rel_attn_combined_dropout)

    def forward(self, input, mask):
        """
        Apply relative-position self-attention.

        Args:
            input: tensor of shape (B, T, hidden_dim).
            mask: ``None`` or a (B, T) tensor that is truthy at real tokens and
                falsy at padding. Non-bool masks are converted with ``.bool()``.

        Returns:
            Tensor of shape (B, T, hidden_dim).

        Raises:
            IndexError: if T exceeds ``max_seq_len`` (the relative distances
                would fall outside the distance embedding table).
        """
        batch_size, timesteps, hidden_dim = input.shape
        if timesteps > self.max_seq_len:
            raise IndexError(
                f"sequence length {timesteps} exceeds max_seq_len={self.max_seq_len}; "
                "relative distances would index outside the distance embedding table"
            )

        query = self.transpose_for_scores(self.query_layer(input))  # (B, H, T, s)
        key = self.transpose_for_scores(self.key_layer(input))  # (B, H, T, s)
        value = self.transpose_for_scores(self.value_layer(input))  # (B, H, T, s)

        # NOTE(review): content scores are divided by dims_per_head ** 0.25, not
        # the conventional sqrt (** 0.5), and the relative terms below are not
        # scaled at all. Left unchanged so trained weights keep their behavior;
        # confirm this is intended.
        attention_scores = torch.matmul(query, key.transpose(-2, -1)) / (
            self.dims_per_head ** 0.25)  # (B, H, T, T)

        positional_embedding = self._relative_position_embeddings(timesteps, input.device)  # (T, T, s)

        # einsum is used due to the optimizations built in - uses opt_einsum
        # https://pytorch.org/docs/stable/generated/torch.einsum.html
        # https://optimized-einsum.readthedocs.io/en/stable/
        relative_position_scores_query = torch.einsum(
            "bhld,lrd->bhlr", query, positional_embedding)  # (B,H,T,s) @ (T,T,s) -> (B,H,T,T)
        attention_scores = attention_scores + relative_position_scores_query
        if self.relative_pos_use_query_key:
            relative_position_scores_key = torch.einsum(
                "bhrd,lrd->bhlr", key, positional_embedding)  # (B,H,T,s) @ (T,T,s) -> (B,H,T,T)
            attention_scores = attention_scores + relative_position_scores_key

        if mask is not None:
            assert mask.shape[-1] == attention_scores.shape[-1], 'mask dimensions does not match embd dimensions'
            keep = mask.bool()  # `~` on an int mask would be a bitwise NOT, not a logical one
            # A (query, key) pair is visible only when both tokens are real.
            pair_keep = keep[:, None, :] & keep[:, :, None]  # (B,1,T) & (B,T,1) -> (B,T,T)
            # -1e9 as before, clamped to the dtype's range so half precision does not overflow.
            mask_value = max(-1e9, torch.finfo(attention_scores.dtype).min)
            # unsqueeze(1) lets the mask broadcast over heads without copying it H times.
            attention_scores = attention_scores.masked_fill(~pair_keep.unsqueeze(1), mask_value)

        attn_prob_dist = F.softmax(attention_scores, dim=-1)
        attn_prob_dist = self.attn_dropout(attn_prob_dist)
        x = torch.matmul(attn_prob_dist, value)  # (B,H,T,T) @ (B,H,T,s) -> (B,H,T,s)
        x = x.transpose(1, 2).reshape(batch_size, timesteps, hidden_dim)  # (B,T,H,s) -> (B,T,D)
        x = self.combined_projection(x)
        x = self.combined_projection_dropout(x)
        return x

    def _relative_position_embeddings(self, timesteps, device):
        """Return the (T, T, s) embeddings of the signed distance between every query/key position pair."""
        position_ids_l = torch.arange(timesteps, dtype=torch.long, device=device).view(-1, 1)  # (T, 1)
        position_ids_r = torch.arange(timesteps, dtype=torch.long, device=device).view(1, -1)  # (1, T)
        # distance[l, r] = l - r, e.g. for T = 4:
        #   [[ 0, -1, -2, -3],
        #    [ 1,  0, -1, -2],
        #    [ 2,  1,  0, -1],
        #    [ 3,  2,  1,  0]]
        distance = position_ids_l - position_ids_r  # (T, T)
        # Adding (max_seq_len - 1) shifts the range [-(max_seq_len - 1), max_seq_len - 1]
        # to [0, 2 * max_seq_len - 2], which are valid rows of the embedding table.
        return self.distance_embedding(distance + (self.max_seq_len - 1))

    def transpose_for_scores(self, x):
        """Split the last dimension into heads: (B, T, D) -> (B, H, T, s)."""
        new_x_shape = x.size()[:-1] + (self.num_heads, self.dims_per_head)  # (B, T, H, s)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)  # (B, H, T, s)
class AdditiveAttention(nn.Module):
    """
    Additive attention pooling over the time dimension.

    Each timestep's state is concatenated with a context vector, scored by a
    small tanh MLP, and the softmax of those scores over time weights the sum
    of the states.

    Adapted from https://github.com/EdwardRaff/Inside-Deep-Learning/blob/bc4dccf13a4711fa681a169ff1ae3a4f07c6ae69/idlmam.py#L389

    Args:
        embd_dims: width of both the states and the context vector.
    """

    def __init__(self, embd_dims):
        super().__init__()
        self.embd_dims = embd_dims
        self.score_layer = nn.Sequential(
            nn.Linear(2 * self.embd_dims, self.embd_dims),
            nn.Tanh(),
            nn.Linear(embd_dims, 1)
        )

    def forward(self, attn_states, avg_context, mask=None):
        """
        Pool ``attn_states`` into one vector per batch element.

        Args:
            attn_states: tensor of shape (B, T, D).
            avg_context: tensor of shape (B, D) paired with every timestep.
            mask: optional (B, T) tensor, truthy at positions that may receive
                attention. Non-bool masks are converted with ``.bool()``.

        Returns:
            Tensor of shape (B, D).
        """
        batch_size, time_steps, dims = attn_states.shape
        # expand() creates a broadcast view instead of stacking T copies in a Python loop.
        tiled_context = avg_context.unsqueeze(1).expand(-1, time_steps, -1)  # (B, T, D)
        scores = self.score_layer(torch.cat((attn_states, tiled_context), dim=2))  # (B, T, 1)
        if mask is not None:
            # Out-of-place fill: the score tensor is not mutated behind autograd's back.
            scores = scores.masked_fill(~mask.bool().unsqueeze(-1), -10000.0)
        weights = F.softmax(scores, dim=1)  # normalize over time
        final_context = (attn_states * weights).sum(dim=1)
        return final_context.view(batch_size, dims)
class PositionalEncoding(nn.Module):
    r"""Inject some information about the relative or absolute position of the tokens
    in the sequence. The positional encodings have the same dimension as
    the embeddings, so that the two can be summed. Here, we use sine and cosine
    functions of different frequencies.

    Adapted from https://github.com/pytorch/examples/blob/0c1654d6913f77f09c0505fb284d977d89c17c1a/word_language_model/model.py#L63

    .. math::
        \text{PosEncoder}(pos, 2i) = \sin(pos / 10000^{2i / d\_model})
        \text{PosEncoder}(pos, 2i+1) = \cos(pos / 10000^{2i / d\_model})

    where ``pos`` is the word position and ``i`` is the embed idx.

    Args:
        d_model: the embed dim (required). Odd values are supported; the last
            column then holds a sine term with no cosine partner.
        dropout: the dropout value (default=0.1).
        max_len: the max. length of the incoming sequence (default=5000).
        batch_first: if True, inputs are [batch size, sequence length, embed dim]
            instead of [sequence length, batch size, embed dim] (default=False).

    Examples:
        >>> pos_encoder = PositionalEncoding(d_model)
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000, batch_first=False):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.batch_first = batch_first

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even column, so
        # drop the last frequency; for even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        # Stored as (max_len, 1, d_model) so it broadcasts over the batch axis
        # of sequence-first input.
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        r"""Add positional encodings to ``x`` and apply dropout.

        Args:
            x: the sequence fed to the positional encoder model (required).

        Shape:
            x: [sequence length, batch size, embed dim], or
               [batch size, sequence length, embed dim] when ``batch_first``.
            output: same shape as ``x``.

        Examples:
            >>> output = pos_encoder(x)
        """
        if self.batch_first:
            # (T, 1, D) -> (1, T, D) so it broadcasts over the leading batch axis.
            x = x + self.pe[:x.size(1)].transpose(0, 1)
        else:
            x = x + self.pe[:x.size(0)]
        return self.dropout(x)
def cur_device(tensor=None):
    """
    Return the device type string, either 'cuda' or 'cpu'.

    With no tensor, the answer reflects whether CUDA is available on this
    machine; with a tensor, it reflects whether that tensor lives on a CUDA
    device.
    """
    on_gpu = torch.cuda.is_available() if tensor is None else tensor.is_cuda
    if on_gpu:
        return 'cuda'
    return 'cpu'
def initialize_weights(m):
    """
    Xavier-uniform initialize ``m.weight`` when it is a tensor with more than one dimension.

    Intended for use with ``model.apply(initialize_weights)``. Modules with no
    ``weight`` attribute, with ``weight`` set to ``None`` (e.g. a norm layer
    built without affine parameters), or with a 1-D weight are left untouched.
    Biases are never modified.

    Args:
        m: the module to initialize in place.
    """
    weight = getattr(m, 'weight', None)
    if not isinstance(weight, torch.Tensor) or weight.dim() <= 1:
        return
    # nn.init functions already run under no_grad, so ``.data`` is not needed.
    # NOTE: kaiming_uniform_ was considered as an alternative initializer.
    nn.init.xavier_uniform_(weight)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment