Skip to content

Instantly share code, notes, and snippets.

@vedantroy
Created June 29, 2022 07:18
Show Gist options
  • Save vedantroy/05e6500ae1bc2f6164b87a6510456007 to your computer and use it in GitHub Desktop.
Save vedantroy/05e6500ae1bc2f6164b87a6510456007 to your computer and use it in GitHub Desktop.
Transformer
import torch
import torch.nn as nn
from params import params
# NOTATION:
# W_k = W (k as subscript)
# Wk = W (k as superscript)
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Implements ``MultiHead(Q, K, V) = Concat(head_1, ..., head_h) Wo`` with
    ``head_i = Attention(Q Wq_i, K Wk_i, V Wv_i)`` and
    ``Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V``. Q, K and V are all
    projections of the same input ``x`` (self-attention).

    Args:
        d_model: Width of the input and output vectors. Must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads; each head has width
            ``d_k = d_model // num_heads``.
        mask: 2-D tensor indexed ``[query_pos, key_pos]``; entries equal to 0
            block that query from attending to that key. It is sliced to the
            current sequence length, so it must be at least
            ``(sequence_len, sequence_len)``. ``None`` disables masking.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide
            ``d_model``.
    """

    def __init__(self, d_model: int, num_heads: int, mask: torch.Tensor):
        super().__init__()
        if num_heads <= 0 or d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        # Must be an int: it is used as a dimension size in `view`, which
        # rejects floats.
        self.d_k = d_model // num_heads
        # A non-persistent buffer moves with the module on `.to(device)` /
        # `.cuda()` but is not written into the state_dict.
        self.register_buffer("mask", mask, persistent=False)
        # The paper's formulas have no bias terms on the input projections:
        #   MultiHead(Q, K, V) = Concat(head_1, ..., head_h)(Wo)
        #   where head_i = Attention(QWq_i, KWk_i, VWv_i)
        # TODO: Why no biases? (Many implementations include them.)
        # Each head conceptually has its own Wq_i, Wk_i, Wv_i of shape
        # (d_model, d_k). Here each set is stored as one (d_model, d_model)
        # matrix, and `split_into_heads` slices the result per head.
        self.Wq = nn.Linear(d_model, d_model, bias=False)
        self.Wk = nn.Linear(d_model, d_model, bias=False)
        self.Wv = nn.Linear(d_model, d_model, bias=False)
        # Output projection Wo, applied to the concatenated heads.
        self.linear = nn.Linear(d_model, d_model)

    def forward(self, x):
        """Attend over ``x`` and return a tensor of the same shape.

        Args:
            x: Tensor of shape (batch, sequence_len, d_model).

        Returns:
            Tensor of shape (batch, sequence_len, d_model).

        Raises:
            ValueError: if ``x`` is not 3-D with last dimension ``d_model``,
                or if the mask is smaller than (sequence_len, sequence_len).
        """
        if x.dim() != 3 or x.shape[-1] != self.d_model:
            raise ValueError(
                f"expected input of shape (batch, sequence_len, {self.d_model}), "
                f"got {tuple(x.shape)}"
            )
        batch_size, sequence_len, d_model = x.shape

        Q = self.split_into_heads(self.Wq(x))
        K = self.split_into_heads(self.Wk(x))
        V = self.split_into_heads(self.Wv(x))
        expected_shape = (batch_size, self.num_heads, sequence_len, self.d_k)
        assert Q.shape == K.shape == V.shape == expected_shape

        # Matmul on >2-D tensors multiplies the last two dimensions and
        # batches over the leading ones (which must match):
        # (B, H, S, d_k) @ (B, H, d_k, S) -> (B, H, S, S).
        scores = (Q @ K.transpose(-2, -1)) / (self.d_k ** 0.5)
        assert scores.shape == (batch_size, self.num_heads, sequence_len, sequence_len)

        if self.mask is not None:
            # From the paper:
            # > We need to prevent leftward information flow in the decoder to
            # > preserve the auto-regressive property. We implement this inside
            # > of scaled dot-product attention by masking out (setting to -inf)
            # > all values in the input of the softmax which correspond to
            # > illegal connections.
            mask = self.mask[:sequence_len, :sequence_len]
            if mask.shape != (sequence_len, sequence_len):
                raise ValueError(
                    f"mask of shape {tuple(self.mask.shape)} is too small for "
                    f"sequence_len={sequence_len}"
                )
            # Use the dtype's most negative finite value instead of a literal
            # such as -1e9, which overflows to -inf in float16. It also avoids
            # NaNs from softmax over a fully masked row.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        weights = torch.softmax(scores, dim=-1)
        context = weights @ V
        assert context.shape == expected_shape

        # (B, H, S, d_k) -> (B, S, H, d_k) -> (B, S, d_model) concatenates the
        # heads. The transpose makes the tensor non-contiguous, so `reshape`
        # (which copies when needed) is required instead of `view`.
        concatted = context.transpose(1, 2).reshape(batch_size, sequence_len, d_model)
        return self.linear(concatted)

    def split_into_heads(self, tensor):
        """Reshape (batch, seq, d_model) into (batch, num_heads, seq, d_k)."""
        batch_size, sequence_len, d_model = tensor.shape
        assert self.d_k * self.num_heads == d_model
        return tensor.view(batch_size, sequence_len, self.num_heads, self.d_k).transpose(1, 2)
class DecoderLayer(nn.Module):
    """One post-norm transformer decoder layer with self-attention only.

    Computes ``x = LayerNorm(x + SelfAttention(x))`` and then
    ``x = LayerNorm(x + FFN(x))``, where ``FFN(x) = lin2(ReLU(lin1(x)))`` is
    applied independently at every position. There is no cross-attention.

    Args:
        num_heads: Number of attention heads.
        d_model: Width of the input and output vectors.
        d_ff: Hidden width of the feed-forward network.
        mask: Attention mask forwarded to ``MultiHeadAttention``; entries
            equal to 0 block attention.
    """

    def __init__(self, num_heads: int, d_model: int, d_ff: int, mask: torch.Tensor):
        super().__init__()
        self.d_model = d_model
        self.d_ff = d_ff
        self.attention = MultiHeadAttention(d_model, num_heads, mask)
        # TODO: How was this epsilon chosen? (PyTorch's default is 1e-5.)
        self.norm1 = nn.LayerNorm(d_model, eps=1e-6)
        # https://stats.stackexchange.com/questions/485910/what-is-the-role-of-feed-forward-layer-in-transformer-neural-network-architectur
        self.lin1 = nn.Linear(d_model, d_ff)
        self.relu = nn.ReLU()
        self.lin2 = nn.Linear(d_ff, d_model)
        self.norm2 = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, x):
        """Apply the attention and feed-forward sublayers, each with add & norm.

        Args:
            x: Tensor of shape (batch, sequence_len, d_model).

        Returns:
            Tensor of the same shape as ``x``.
        """
        # Shapes come from the input itself rather than a global config, so
        # any batch size or sequence length works.
        batch_size, sequence_len, d_model = x.shape

        # Sublayer 1: self-attention, then add & normalize.
        attention = self.attention(x)
        assert attention.shape == (batch_size, sequence_len, d_model)
        x = self.norm1(x + attention)

        # Sublayer 2: feed-forward network, then add & normalize again.
        hidden = self.relu(self.lin1(x))
        assert hidden.shape == (batch_size, sequence_len, self.d_ff)
        ffn_out = self.lin2(hidden)
        assert ffn_out.shape == (batch_size, sequence_len, d_model)
        return self.norm2(x + ffn_out)
class Transformer(nn.Module):
    """Decoder-only transformer that maps token ids to vocabulary logits.

    Pipeline: token embedding + learned positional embedding -> ``layers``
    stacked ``DecoderLayer`` modules -> linear projection to ``vocab_size``.

    Args:
        vocab_size: Number of distinct token ids.
        num_heads: Attention heads per layer.
        d_model: Embedding / hidden width.
        sequence_len: Maximum supported sequence length, which is the size of
            the positional-embedding table.
        layers: Number of decoder layers.
        mask: Attention mask shared by every layer; entries equal to 0 block
            attention.
        d_ff: Feed-forward hidden width of each layer. Defaults to
            ``4 * d_model``, the ratio used in the original paper (2048 / 512).
    """

    def __init__(self, vocab_size: int, num_heads: int, d_model: int, sequence_len: int, layers: int, mask: torch.Tensor, d_ff: "int | None" = None):
        super().__init__()
        if d_ff is None:
            d_ff = 4 * d_model
        self.d_model = d_model
        self.max_sequence_len = sequence_len
        self.vocab_embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)
        self.positional_embedding = nn.Embedding(num_embeddings=sequence_len, embedding_dim=d_model)
        self.decoder_layers = nn.ModuleList(
            [DecoderLayer(num_heads, d_model, d_ff, mask) for _ in range(layers)]
        )
        # Maps the decoder outputs back to vocabulary logits.
        # Alternatively, to re-use the input embedding matrix ("weight
        # tying"), you could compute inside `forward`:
        #   torch.matmul(decoder_output, self.vocab_embedding.weight.transpose(0, 1))
        # With tied weights, autograd accumulates the gradients from both uses
        # (embedding lookup and output projection) into the single shared
        # parameter.
        # TODO: Why don't we need to subtract positional encodings if using tied?
        # https://github.com/tunz/transformer-pytorch/blob/e7266679f0b32fd99135ea617213f986ceede056/model/transformer.py#L292
        self.linear = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """Compute next-token logits.

        Args:
            x: Tensor of integer token ids, shape (batch, sequence_len), with
                sequence_len no larger than the constructor's ``sequence_len``.

        Returns:
            Tensor of shape (batch, sequence_len, vocab_size).

        Raises:
            ValueError: if the input sequence is longer than the positional
                table.
        """
        batch_size, sequence_len = x.shape
        if sequence_len > self.max_sequence_len:
            raise ValueError(
                f"sequence length {sequence_len} exceeds maximum {self.max_sequence_len}"
            )
        # Positional embeddings are looked up by position index 0..S-1, not by
        # token id. The (S, d_model) result broadcasts over the batch.
        positions = torch.arange(sequence_len, device=x.device)
        embeddings = self.vocab_embedding(x) + self.positional_embedding(positions)
        assert embeddings.shape == (batch_size, sequence_len, self.d_model)

        # Each layer consumes the previous layer's output.
        decoder_output = embeddings
        for layer in self.decoder_layers:
            decoder_output = layer(decoder_output)
        return self.linear(decoder_output)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment