@SivilTaram
Created January 5, 2022 03:24
Minimal Transformer Demo
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class Config:
    embed_drop = 0.1
    residual_drop = 0.1
    attn_drop = 0.1
    n_layer = 3
    n_head = 8
    n_embed = 128

    def __init__(self, vocab_size, seq_len, **kwargs):
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        for k, v in kwargs.items():
            setattr(self, k, v)

class PositionalEmbedding(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # precompute the fixed sinusoidal table once and cache it as a buffer
        positional_embedding = torch.zeros(config.seq_len, config.n_embed)
        for pos in range(config.seq_len):
            for i in range(0, config.n_embed, 2):
                positional_embedding[pos, i] = math.sin(pos / (10000 ** ((2 * i) / config.n_embed)))
                positional_embedding[pos, i + 1] = math.cos(pos / (10000 ** ((2 * i + 2) / config.n_embed)))
        positional_embedding = positional_embedding.unsqueeze(0)
        self.register_buffer('positional_embedding', positional_embedding)

    def forward(self, x):
        # scale the token embeddings, then add the positional signal
        x = x * math.sqrt(self.config.n_embed)
        seq_len = x.size(1)
        x = x + self.positional_embedding[:, :seq_len]
        return x

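# For reference, the sinusoidal encoding of Vaswani et al. (2017) is
#     PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
#     PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model)),
# with each sin/cos pair sharing one frequency. The loop above follows a common
# tutorial variant whose exponents differ slightly from the paper's, but it still
# assigns every position a fixed, position-dependent vector. The sqrt(n_embed)
# factor in forward() is the embedding scaling used in the paper.
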
class MultiHeadAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.d_k = config.n_embed // config.n_head
        self.d_model = config.n_embed
        self.h = config.n_head
        self.q_linear = nn.Linear(self.d_model, self.d_model)
        self.k_linear = nn.Linear(self.d_model, self.d_model)
        self.v_linear = nn.Linear(self.d_model, self.d_model)
        self.attn_dropout = nn.Dropout(config.attn_drop)
        self.residual_dropout = nn.Dropout(config.residual_drop)
        self.out_linear = nn.Linear(self.d_model, self.d_model)
        # causal (lower-triangular) mask, shaped (1, 1, seq_len, seq_len) for broadcasting
        mask = torch.tril(torch.ones(config.seq_len, config.seq_len).long()
                          .view(1, 1, config.seq_len, config.seq_len))
        self.register_buffer("mask", mask)

    def forward(self, q, k, v, mask, is_decoder=False):
        # q, k, v: (batch_size, seq_len, hidden_size)
        batch_size, seq_len, hidden_size = q.size()
        # apply the linear projections and split into h heads:
        # (batch_size, seq_len, h, d_k) -> (batch_size, h, seq_len, d_k)
        k = self.k_linear(k).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
        q = self.q_linear(q).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
        v = self.v_linear(v).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
        # (batch_size, h, seq_len, d_k) x (batch_size, h, d_k, seq_len) -> (batch_size, h, seq_len, seq_len)
        attn = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.d_k))
        if mask is not None:
            if is_decoder:
                # combine the padding mask with the causal mask for decoder self-attention
                mask = mask.long() & self.mask[:, :, :seq_len, :seq_len]
            attn = attn.masked_fill(mask == 0, float("-inf"))
        attn = F.softmax(attn, dim=-1)
        attn = self.attn_dropout(attn)
        # (batch_size, h, seq_len, seq_len) x (batch_size, h, seq_len, d_k) -> (batch_size, h, seq_len, d_k)
        output = attn @ v
        # (batch_size, h, seq_len, d_k) -> (batch_size, seq_len, hidden_size)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, hidden_size)
        output = self.out_linear(output)
        return output

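# The forward pass above is standard scaled dot-product attention,
#     Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V,
# computed in parallel over h heads and recombined by out_linear. Positions whose
# mask entry is 0 are filled with -inf before the softmax so they get zero weight.
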
class PointWiseFeedForward(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.d_model = config.n_embed
        self.d_ff = config.n_embed * 4
        self.dropout = nn.Dropout(config.residual_drop)
        self.linear_1 = nn.Linear(self.d_model, self.d_ff)
        self.linear_2 = nn.Linear(self.d_ff, self.d_model)

    def forward(self, x):
        # expand to d_ff with ReLU and dropout, then project back to d_model
        x = self.dropout(F.relu(self.linear_1(x)))
        x = self.linear_2(x)
        return x

class LayerNorm(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.size = config.n_embed
        self.alpha = nn.Parameter(torch.ones(self.size))
        self.bias = nn.Parameter(torch.zeros(self.size))
        self.eps = 1e-6

    def forward(self, x):
        # normalize over the embedding dimension with a learnable scale and bias
        norm = self.alpha * (x - x.mean(dim=-1, keepdim=True)) / (x.std(dim=-1, keepdim=True) + self.eps) + self.bias
        return norm

class EncoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.norm_1 = LayerNorm(config)
        self.norm_2 = LayerNorm(config)
        self.attn = MultiHeadAttention(config)
        self.ff = PointWiseFeedForward(config)
        self.dropout = nn.Dropout(config.residual_drop)

    def forward(self, x, mask):
        # self-attention and feed-forward sub-layers, each wrapped in residual + LayerNorm
        x = self.norm_1(x + self.dropout(self.attn(x, x, x, mask)))
        x = self.norm_2(x + self.dropout(self.ff(x)))
        return x

class DecoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.norm_1 = LayerNorm(config)
        self.norm_2 = LayerNorm(config)
        self.norm_3 = LayerNorm(config)
        self.attn = MultiHeadAttention(config)
        self.cross_attn = MultiHeadAttention(config)
        self.ff = PointWiseFeedForward(config)
        self.dropout = nn.Dropout(config.residual_drop)

    def forward(self, x, src_outputs, src_mask, tgt_mask):
        # masked self-attention over the target, cross-attention over the encoder
        # outputs, then the feed-forward sub-layer
        x = self.norm_1(x + self.dropout(self.attn(x, x, x, tgt_mask, is_decoder=True)))
        x = self.norm_2(x + self.dropout(self.cross_attn(x, src_outputs, src_outputs, src_mask)))
        x = self.norm_3(x + self.dropout(self.ff(x)))
        return x

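# Both layer types use the post-norm residual pattern x = LayerNorm(x + Dropout(Sublayer(x)))
# from the original Transformer; many later implementations instead apply the LayerNorm
# before each sub-layer (pre-norm), which tends to train more stably for deep stacks.
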
class Encoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_layer = config.n_layer
        self.word_embedding = nn.Embedding(config.vocab_size, config.n_embed)
        self.positional_embedding = PositionalEmbedding(config)
        self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.n_layer)])
        self.norm = LayerNorm(config)

    def forward(self, x, mask):
        x = self.word_embedding(x)
        x = self.positional_embedding(x)
        for i in range(self.n_layer):
            x = self.layers[i](x, mask)
        return self.norm(x)

class Decoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_layer = config.n_layer
        self.word_embedding = nn.Embedding(config.vocab_size, config.n_embed)
        self.positional_embedding = PositionalEmbedding(config)
        self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.n_layer)])
        self.norm = LayerNorm(config)

    def forward(self, x, src_outputs, src_mask, tgt_mask):
        x = self.positional_embedding(self.word_embedding(x))
        for i in range(self.n_layer):
            x = self.layers[i](x, src_outputs, src_mask, tgt_mask)
        return self.norm(x)

class Transformer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.encoder = Encoder(config)
        self.decoder = Decoder(config)
        self.out = nn.Linear(config.n_embed, config.vocab_size)

    def forward(self, src, tgt, src_mask, tgt_mask):
        src_outputs = self.encoder(src, src_mask)
        tgt_outputs = self.decoder(tgt, src_outputs, src_mask, tgt_mask)
        output = self.out(tgt_outputs)
        return output

if __name__ == '__main__':
    config = Config(vocab_size=10, seq_len=25)
    model = Transformer(config)
    # toy batch in which token id 0 plays the role of padding
    input_tensor = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 0, 0]]).long()
    # padding masks shaped (batch_size, 1, 1, seq_len) so they broadcast over heads
    # and query positions and mask out padded key positions
    input_mask = (input_tensor != 0).unsqueeze(1).unsqueeze(2)
    tgt_tensor = torch.tensor([[7, 8, 9, 0]]).long()
    tgt_mask = (tgt_tensor != 0).unsqueeze(1).unsqueeze(2)
    output = model(input_tensor, tgt_tensor, input_mask, tgt_mask)
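    # A minimal sketch of how the demo could continue into a single training step
    # on the toy batch above; the shape check, optimizer, learning rate, and the
    # use of padding id 0 as ignore_index are illustrative choices, not part of
    # the original gist.
    print(output.shape)  # (batch_size, tgt_len, vocab_size)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    # in a real setup the decoder input would be the target shifted right and the
    # labels the unshifted target; here we simply reuse tgt_tensor as the labels
    loss = F.cross_entropy(output.view(-1, config.vocab_size),
                           tgt_tensor.view(-1), ignore_index=0)
    loss.backward()
    optimizer.step()
    print(f"toy training-step loss: {loss.item():.4f}")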