Minimal Transformer Demo
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class Config:
    embed_drop = 0.1
    residual_drop = 0.1
    attn_drop = 0.1
    n_layer = 3
    n_head = 8
    n_embed = 128

    def __init__(self, vocab_size, seq_len, **kwargs):
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        for k, v in kwargs.items():
            setattr(self, k, v)

class PositionalEmbedding(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # precompute sinusoidal position encodings: sine on even dims, cosine on odd dims
        positional_embedding = torch.zeros(config.seq_len, config.n_embed)
        for pos in range(config.seq_len):
            for i in range(0, config.n_embed, 2):
                positional_embedding[pos, i] = math.sin(pos / (10000 ** ((2 * i) / config.n_embed)))
                positional_embedding[pos, i + 1] = math.cos(pos / (10000 ** ((2 * i + 2) / config.n_embed)))
        positional_embedding = positional_embedding.unsqueeze(0)
        self.register_buffer('positional_embedding', positional_embedding)

    def forward(self, x):
        # scale word embeddings before adding the fixed position encodings
        x = x * math.sqrt(self.config.n_embed)
        seq_len = x.size(1)
        x = x + self.positional_embedding[:, :seq_len]
        return x

class MultiHeadAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.d_k = config.n_embed // config.n_head
        self.d_model = config.n_embed
        self.h = config.n_head
        self.q_linear = nn.Linear(self.d_model, self.d_model)
        self.k_linear = nn.Linear(self.d_model, self.d_model)
        self.v_linear = nn.Linear(self.d_model, self.d_model)
        self.attn_dropout = nn.Dropout(config.attn_drop)
        self.residual_dropout = nn.Dropout(config.residual_drop)
        self.out_linear = nn.Linear(self.d_model, self.d_model)
        # lower-triangular (causal) mask used for decoder self-attention
        mask = torch.tril(torch.ones(config.seq_len, config.seq_len).long()
                          .view(1, 1, config.seq_len, config.seq_len))
        self.register_buffer("mask", mask)

    def forward(self, q, k, v, mask, is_decoder=False):
        # (batch_size, seq_len, hidden_size)
        batch_size, seq_len, hidden_size = q.size()
        # perform linear projections and split into h heads
        # (batch_size, seq_len, h, d_k) -> (batch_size, h, seq_len, d_k)
        k = self.k_linear(k).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
        q = self.q_linear(q).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
        v = self.v_linear(v).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
        # (batch_size, h, seq_len, d_k) x (batch_size, h, d_k, seq_len) -> (batch_size, h, seq_len, seq_len)
        attn = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.d_k))
        if is_decoder:
            # combine the padding mask with the causal mask for decoder self-attention
            mask = mask.long() & self.mask[:, :, :seq_len, :seq_len]
        attn = attn.masked_fill(mask == 0, float("-inf"))
        attn = F.softmax(attn, dim=-1)
        attn = self.attn_dropout(attn)
        # (batch_size, h, seq_len, seq_len) x (batch_size, h, seq_len, d_k) -> (batch_size, h, seq_len, d_k)
        output = attn @ v
        # (batch_size, h, seq_len, d_k) -> (batch_size, seq_len, hidden_size)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, hidden_size)
        output = self.out_linear(output)
        return output

class PointWiseFeedForward(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.d_model = config.n_embed
        self.d_ff = config.n_embed * 4
        self.dropout = nn.Dropout(config.residual_drop)
        self.linear_1 = nn.Linear(self.d_model, self.d_ff)
        self.linear_2 = nn.Linear(self.d_ff, self.d_model)

    def forward(self, x):
        x = self.dropout(F.relu(self.linear_1(x)))
        x = self.linear_2(x)
        return x


class LayerNorm(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.size = config.n_embed
        self.alpha = nn.Parameter(torch.ones(self.size))
        self.bias = nn.Parameter(torch.zeros(self.size))
        self.eps = 1e-6

    def forward(self, x):
        norm = self.alpha * (x - x.mean(dim=-1, keepdim=True)) / (x.std(dim=-1, keepdim=True) + self.eps) + self.bias
        return norm

class EncoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.norm_1 = LayerNorm(config)
        self.norm_2 = LayerNorm(config)
        self.attn = MultiHeadAttention(config)
        self.ff = PointWiseFeedForward(config)
        self.dropout = nn.Dropout(config.residual_drop)

    def forward(self, x, mask):
        x = self.norm_1(x + self.dropout(self.attn(x, x, x, mask)))
        x = self.norm_2(x + self.dropout(self.ff(x)))
        return x


class DecoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.norm_1 = LayerNorm(config)
        self.norm_2 = LayerNorm(config)
        self.norm_3 = LayerNorm(config)
        self.attn = MultiHeadAttention(config)
        self.cross_attn = MultiHeadAttention(config)
        self.ff = PointWiseFeedForward(config)
        self.dropout = nn.Dropout(config.residual_drop)

    def forward(self, x, src_outputs, src_mask, tgt_mask):
        x = self.norm_1(x + self.dropout(self.attn(x, x, x, tgt_mask, is_decoder=True)))
        x = self.norm_2(x + self.dropout(self.cross_attn(x, src_outputs, src_outputs, src_mask)))
        x = self.norm_3(x + self.dropout(self.ff(x)))
        return x

class Encoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_layer = config.n_layer
        self.word_embedding = nn.Embedding(config.vocab_size, config.n_embed)
        self.positional_embedding = PositionalEmbedding(config)
        self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.n_layer)])
        self.norm = LayerNorm(config)

    def forward(self, x, mask):
        x = self.word_embedding(x)
        x = self.positional_embedding(x)
        for i in range(self.n_layer):
            x = self.layers[i](x, mask)
        return self.norm(x)


class Decoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_layer = config.n_layer
        self.word_embedding = nn.Embedding(config.vocab_size, config.n_embed)
        self.positional_embedding = PositionalEmbedding(config)
        self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.n_layer)])
        self.norm = LayerNorm(config)

    def forward(self, x, src_outputs, src_mask, tgt_mask):
        x = self.positional_embedding(self.word_embedding(x))
        for i in range(self.n_layer):
            x = self.layers[i](x, src_outputs, src_mask, tgt_mask)
        return self.norm(x)

class Transformer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.encoder = Encoder(config)
        self.decoder = Decoder(config)
        self.out = nn.Linear(config.n_embed, config.vocab_size)

    def forward(self, src, tgt, src_mask, tgt_mask):
        src_outputs = self.encoder(src, src_mask)
        tgt_outputs = self.decoder(tgt, src_outputs, src_mask, tgt_mask)
        output = self.out(tgt_outputs)
        return output

if __name__ == '__main__':
    config = Config(vocab_size=10, seq_len=25)
    model = Transformer(config)
    # token id 0 is treated as padding
    input_tensor = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 0, 0]]).long()
    # padding masks of shape (batch_size, 1, seq_len), broadcast over heads and query positions
    input_mask = (input_tensor != 0).unsqueeze(1)
    tgt_tensor = torch.tensor([[7, 8, 9, 0]]).long()
    tgt_mask = (tgt_tensor != 0).unsqueeze(1)
    output = model(input_tensor, tgt_tensor, input_mask, tgt_mask)
    print(output.shape)  # (batch_size, tgt_len, vocab_size)
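
The demo above only runs a forward pass. As a quick illustration of how the output logits could be used, here is a minimal training-step sketch that is not part of the original gist: it reuses `model`, `config`, `input_tensor`, `input_mask`, and `tgt_tensor` from the `__main__` block, and it assumes a teacher-forcing setup in which the decoder input is the target shifted right, id 0 marks padding, and Adam is the optimizer.

# Hypothetical training-step sketch (assumptions: shifted-right decoder input,
# padding id 0, Adam optimizer); reuses names from the __main__ demo above.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

decoder_input = tgt_tensor[:, :-1]        # e.g. [[7, 8, 9]]
decoder_target = tgt_tensor[:, 1:]        # e.g. [[8, 9, 0]]
decoder_mask = (decoder_input != 0).unsqueeze(1)

logits = model(input_tensor, decoder_input, input_mask, decoder_mask)
loss = F.cross_entropy(logits.reshape(-1, config.vocab_size),
                       decoder_target.reshape(-1),
                       ignore_index=0)      # ignore padded target positions
loss.backward()
optimizer.step()
optimizer.zero_grad()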