"""
Transformer in ~80 lines of code.
From Thomas Wolf's tweet https://twitter.com/Thom_Wolf/status/1129658539142766592.
"""
import torch
from torch import nn

class Transformer(nn.Module):
    """
    Transformer (GPT-2-style architecture).

    Args:
        embed_dim: Dimensionality of the token and position embeddings.
        hidden_dim: Dimensionality of the feed-forward hidden layer.
        num_embed: Vocabulary size of `x`.
        num_pos: Maximum sequence length (number of positional embeddings).
        num_heads: Number of attention heads in each attention layer.
        num_layers: Number of Transformer blocks.
        dropout: Dropout probability for all layers.
    """
    def __init__(self, embed_dim=768, hidden_dim=768, num_embed=4992, num_pos=768, num_heads=6, num_layers=6, dropout=0.1):
        super().__init__()
        self.token_embeddings = nn.Embedding(num_embed, embed_dim)
        self.position_embeddings = nn.Embedding(num_pos, embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.attentions, self.feed_forwards = nn.ModuleList(), nn.ModuleList()
        self.ln_1, self.ln_2 = nn.ModuleList(), nn.ModuleList()
        for _ in range(num_layers):
            self.attentions.append(nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout))
            self.feed_forwards.append(nn.Sequential(nn.Linear(embed_dim, hidden_dim),
                                                    nn.ReLU(),
                                                    nn.Linear(hidden_dim, embed_dim)))
            self.ln_1.append(nn.LayerNorm(embed_dim, eps=1e-12))
            self.ln_2.append(nn.LayerNorm(embed_dim, eps=1e-12))
        # Language-modelling head: projects hidden states back onto the vocabulary.
        self.head = nn.Linear(hidden_dim, num_embed)
    def forward(self, x):
        # x: LongTensor of token ids, shape (seq_len, batch_size).
        positions = torch.arange(len(x), device=x.device).unsqueeze(-1)
        h = self.token_embeddings(x)
        h = h + self.position_embeddings(positions).expand_as(h)
        h = self.dropout(h)
        # Causal mask: -inf above the diagonal so each position attends only to itself and the past.
        attn_mask = torch.full((len(x), len(x)), -float('Inf'), device=h.device, dtype=h.dtype)
        attn_mask = torch.triu(attn_mask, diagonal=1)
        for layer_norm_1, attention, layer_norm_2, feed_forward in zip(self.ln_1, self.attentions,
                                                                       self.ln_2, self.feed_forwards):
            # Layer norm, then masked self-attention, with a residual connection.
            h = layer_norm_1(h)
            x, _ = attention(h, h, h, attn_mask=attn_mask, need_weights=False)
            x = self.dropout(x)
            h = x + h
            # Layer norm, then position-wise feed-forward, with a residual connection.
            h = layer_norm_2(h)
            x = feed_forward(h)
            x = self.dropout(x)
            h = x + h
        return self.head(h)
# test
transformer = Transformer()
seq_len = 32
batch_size = 4
# nn.MultiheadAttention expects (seq_len, batch_size, embed_dim) by default, so the input is sequence-first.
x = torch.randint(low=0, high=transformer.token_embeddings.num_embeddings, size=(seq_len, batch_size))
y = transformer(x)
y.shape  # torch.Size([32, 4, 4992]), i.e. (seq_len, batch_size, vocab_size)
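
# Not part of the original gist: a minimal sketch of how the logits could be turned into a
# next-token language-modelling loss, assuming the sequence-first input `x` defined above.
import torch.nn.functional as F

logits = transformer(x)                                    # (seq_len, batch_size, vocab_size)
shift_logits = logits[:-1].reshape(-1, logits.size(-1))    # predictions for positions 0..seq_len-2
shift_targets = x[1:].reshape(-1)                          # targets are the tokens shifted left by one
loss = F.cross_entropy(shift_logits, shift_targets)
loss.backward()                                            # gradients for one optimisation step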