Skip to content

Instantly share code, notes, and snippets.

Created June 29, 2022 01:56
Show Gist options
  • Save jsksxs360/3ae3b176352fa78a4fca39fff0ffe648 to your computer and use it in GitHub Desktop.
Save jsksxs360/3ae3b176352fa78a4fca39fff0ffe648 to your computer and use it in GitHub Desktop.
Transformer Encoder implemented in PyTorch
import torch
from torch import nn
import torch.nn.functional as F
from math import sqrt
class AttentionHead(nn.Module):
    """A single scaled dot-product attention head.

    Projects query/key/value from ``embed_dim`` to ``head_dim`` features and
    returns ``softmax(Q K^T / sqrt(head_dim)) V``.
    """

    def __init__(self, embed_dim, head_dim):
        # Must run before any submodule is assigned, or nn.Module raises.
        super().__init__()
        self.q = nn.Linear(embed_dim, head_dim)
        self.k = nn.Linear(embed_dim, head_dim)
        self.v = nn.Linear(embed_dim, head_dim)

    def forward(self, query, key, value, mask=None):
        """Compute attention for one head.

        Args:
            query, key, value: batched 3-D tensors whose last dim is
                ``embed_dim`` (``torch.bmm`` requires 3-D inputs); presumably
                (batch, seq_len, embed_dim) — confirm against callers.
            mask: optional tensor broadcastable to the (batch, q_len, k_len)
                score matrix; positions where ``mask == 0`` receive no weight.

        Returns:
            Tensor of shape (batch, q_len, head_dim).
        """
        query, key, value = self.q(query), self.k(key), self.v(value)
        scores = torch.bmm(query, key.transpose(1, 2)) / sqrt(query.size(-1))
        if mask is not None:
            # Use the most negative finite value instead of -inf: a fully
            # masked row (e.g. a padded query) then yields uniform weights
            # rather than NaN, which would otherwise leak into other positions
            # in later layers via 0 * NaN in the value product.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        weights = F.softmax(scores, dim=-1)
        return torch.bmm(weights, value)
class MultiHeadAttention(nn.Module):
    """Multi-head attention built from independent ``AttentionHead`` modules.

    Each of ``config.num_attention_heads`` heads projects to
    ``hidden_size // num_attention_heads`` features. The head outputs are
    concatenated back to ``hidden_size`` and mixed by a final linear layer.
    """

    def __init__(self, config):
        super().__init__()
        embed_dim = config.hidden_size
        num_heads = config.num_attention_heads
        if embed_dim % num_heads != 0:
            # Otherwise the concatenated heads are narrower than embed_dim and
            # output_linear fails with a shape mismatch at forward time.
            raise ValueError(
                f"hidden_size ({embed_dim}) must be divisible by "
                f"num_attention_heads ({num_heads})"
            )
        head_dim = embed_dim // num_heads
        self.heads = nn.ModuleList(
            [AttentionHead(embed_dim, head_dim) for _ in range(num_heads)]
        )
        self.output_linear = nn.Linear(embed_dim, embed_dim)

    def forward(self, query, key, value, mask=None, query_mask=None, key_mask=None):
        """Run all heads and project the concatenated result.

        Args:
            query, key, value: inputs forwarded unchanged to every head.
            mask: optional attention mask forwarded to every head
                (positions where ``mask == 0`` are ignored).
            query_mask, key_mask: optional per-token padding masks, presumably
                (batch, q_len) and (batch, k_len). When both are given they
                replace ``mask``.

        Returns:
            Tensor with last dimension ``hidden_size``.
        """
        if query_mask is not None and key_mask is not None:
            # Outer product of the padding masks: entry (i, j) is True only
            # when both query i and key j are real tokens. Boolean ops work for
            # any mask dtype (torch.bmm would reject integer masks on CUDA).
            mask = query_mask.unsqueeze(-1).bool() & key_mask.unsqueeze(1).bool()
        x = torch.cat([h(query, key, value, mask) for h in self.heads], dim=-1)
        return self.output_linear(x)
class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear -> Dropout.

    Expands from ``config.hidden_size`` to ``config.intermediate_size`` and back,
    then applies dropout with ``config.hidden_dropout_prob``.
    """

    def __init__(self, config):
        # Must run before any submodule is assigned, or nn.Module raises.
        super().__init__()
        self.linear_1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.linear_2 = nn.Linear(config.intermediate_size, config.hidden_size)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, x):
        """Apply the network to the last dimension of ``x`` (size ``hidden_size``)."""
        x = self.linear_1(x)
        x = self.gelu(x)
        x = self.linear_2(x)
        x = self.dropout(x)
        return x
class TransformerEncoderLayer(nn.Module):
    """One pre-layer-norm encoder block: self-attention and feed-forward.

    Each sub-layer normalizes its input first and adds its output back to the
    un-normalized input as a residual (skip) connection.
    """

    def __init__(self, config):
        # Must run before any submodule is assigned, or nn.Module raises.
        super().__init__()
        self.layer_norm_1 = nn.LayerNorm(config.hidden_size)
        self.layer_norm_2 = nn.LayerNorm(config.hidden_size)
        self.attention = MultiHeadAttention(config)
        self.feed_forward = FeedForward(config)

    def forward(self, x, mask=None):
        """Apply the block to hidden states ``x`` (last dim ``hidden_size``).

        ``mask`` is passed straight through to the self-attention sub-layer.
        """
        # Apply layer normalization and then copy input into query, key, value
        hidden_state = self.layer_norm_1(x)
        # Apply attention with a skip connection
        x = x + self.attention(hidden_state, hidden_state, hidden_state, mask=mask)
        # Apply feed-forward layer with a skip connection
        x = x + self.feed_forward(self.layer_norm_2(x))
        return x
class Embeddings(nn.Module):
    """Sum learned token and absolute-position embeddings, then normalize.

    Reads ``vocab_size``, ``hidden_size``, ``max_position_embeddings`` and
    ``hidden_dropout_prob`` from ``config``. Uses ``config.layer_norm_eps``
    when present, falling back to 1e-12.
    """

    def __init__(self, config):
        # Must run before any submodule is assigned, or nn.Module raises.
        super().__init__()
        self.token_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size
        )
        self.layer_norm = nn.LayerNorm(
            config.hidden_size, eps=getattr(config, "layer_norm_eps", 1e-12)
        )
        # Use the configured rate (as FeedForward does); a bare nn.Dropout()
        # would silently apply p=0.5.
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids):
        """Embed token ids of shape (batch, seq_len) into (batch, seq_len, hidden_size)."""
        # Create position IDs for input sequence, on the same device as the
        # inputs so the lookup also works when the model lives on a GPU.
        seq_length = input_ids.size(1)
        position_ids = torch.arange(
            seq_length, dtype=torch.long, device=input_ids.device
        ).unsqueeze(0)
        # Create token and position embeddings
        token_embeddings = self.token_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        # Combine token and position embeddings; (1, seq_len, hidden)
        # position embeddings broadcast over the batch dimension.
        embeddings = token_embeddings + position_embeddings
        embeddings = self.layer_norm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
class TransformerEncoder(nn.Module):
    """Embedding layer followed by ``config.num_hidden_layers`` encoder blocks."""

    def __init__(self, config):
        # Must run before any submodule is assigned, or nn.Module raises.
        super().__init__()
        self.embeddings = Embeddings(config)
        self.layers = nn.ModuleList(
            [TransformerEncoderLayer(config) for _ in range(config.num_hidden_layers)]
        )

    def forward(self, x, mask=None):
        """Encode token ids ``x`` (presumably (batch, seq_len)) into hidden states.

        ``mask`` is passed unchanged to every layer. It must broadcast against
        the per-head (batch, seq_len, seq_len) attention scores; positions where
        ``mask == 0`` are ignored.
        """
        x = self.embeddings(x)
        for layer in self.layers:
            x = layer(x, mask)
        return x
if __name__ == '__main__':
    # Demo: encode a short sentence with a BERT-base-shaped encoder.
    from transformers import AutoConfig
    from transformers import AutoTokenizer

    model_ckpt = "bert-base-uncased"
    tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
    config = AutoConfig.from_pretrained(model_ckpt)

    text = "time flies like an arrow"
    inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False)

    # Freshly constructed (randomly initialized) encoder using only the
    # checkpoint's hyperparameters; no pretrained weights are loaded here.
    encoder = TransformerEncoder(config)
    encoder.eval()  # disable dropout for a deterministic demo pass
    with torch.no_grad():
        hidden_states = encoder(inputs.input_ids)
    # Expected: (1, num_tokens, config.hidden_size)
    print(hidden_states.size())
Copy link

Why do you implement multi-head attention by wrapping separate single-head attention modules, instead of adding an extra (n+1-th) head dimension to the tensors and computing all heads in one batched attention operation?

Copy link

Why do you implement multi-head attention by wrapping separate single-head attention modules, instead of adding an extra (n+1-th) head dimension to the tensors and computing all heads in one batched attention operation?


Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment