@Butanium
Last active February 1, 2024 12:58
femtoGPT
"""
A minimal pytorch implementation of a multi-head attention transformer inspired by nanoGPT
"""
from dataclasses import dataclass
import torch.nn as nn
import torch.nn.functional as F
import torch as th
import numpy as np

class SelfAttention(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.QKV = nn.Linear(cfg.emb_dim, 3 * cfg.num_heads * cfg.hidden_dim)
        self.emb_dim = cfg.num_heads * cfg.hidden_dim
        self.proj = nn.Linear(self.emb_dim, cfg.emb_dim)
        self.num_heads = cfg.num_heads
        self.hidden_dim = cfg.hidden_dim
        # Causal mask for a fixed maximum context length of 10
        self.register_buffer(
            "mask", th.ones((10, 10)).tril().view((1, 1, 10, 10)) == 0
        )

    def forward(self, x):
        # x: (B, S, emb_dim)
        B, S, _ = x.shape
        # qkv: (B, S, 3 * num_heads * hidden_dim)
        qkv = self.QKV(x)
        # Split Q, K and V on the last dimension; each is (B, S, num_heads * hidden_dim)
        q, k, v = qkv.split(self.emb_dim, dim=2)
        # Add a head dimension to get (B, num_heads, S, hidden_dim),
        # so the softmax is computed separately for each head
        q = q.view((B, S, self.num_heads, self.hidden_dim)).transpose(1, 2)
        k = k.view((B, S, self.num_heads, self.hidden_dim)).transpose(1, 2)
        v = v.view((B, S, self.num_heads, self.hidden_dim)).transpose(1, 2)
        # Attention logits: (B, num_heads, S, S)
        log_att = th.einsum('bhsd, bhSd -> bhsS', q, k) / np.sqrt(self.hidden_dim)
        # Mask out future positions; slice the buffer to the current sequence length
        log_att.masked_fill_(self.mask[:, :, :S, :S], -th.inf)
        attention = F.softmax(log_att, dim=-1)
        # Weighted sum of the values: (B, num_heads, S, hidden_dim)
        z = th.einsum('bhsS, bhSd -> bhsd', attention, v)
        # Merge the heads back into (B, S, num_heads * hidden_dim)
        z = z.transpose(1, 2).reshape((B, S, self.emb_dim))
        out = self.proj(z)
        return out
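
# The check below is a sketch that is not part of the original gist: it assumes
# PyTorch >= 2.0 and compares the hand-rolled attention above with torch's
# built-in F.scaled_dot_product_attention and its causal masking. The
# SimpleNamespace config is a stand-in for the Config dataclass defined further
# down in the file.
from types import SimpleNamespace

_cfg = SimpleNamespace(num_heads=4, hidden_dim=8, emb_dim=32)
_att = SelfAttention(_cfg)
_x = th.randn(2, 10, _cfg.emb_dim)
_q, _k, _v = _att.QKV(_x).split(_att.emb_dim, dim=2)
_q = _q.view((2, 10, _cfg.num_heads, _cfg.hidden_dim)).transpose(1, 2)
_k = _k.view((2, 10, _cfg.num_heads, _cfg.hidden_dim)).transpose(1, 2)
_v = _v.view((2, 10, _cfg.num_heads, _cfg.hidden_dim)).transpose(1, 2)
_z = F.scaled_dot_product_attention(_q, _k, _v, is_causal=True)
_ref = _att.proj(_z.transpose(1, 2).reshape((2, 10, _att.emb_dim)))
assert th.allclose(_att(_x), _ref, atol=1e-5)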

class FeedForward(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.emb = nn.Linear(cfg.emb_dim, 4 * cfg.emb_dim)
        self.act = nn.GELU()
        self.proj = nn.Linear(4 * cfg.emb_dim, cfg.emb_dim)

    def forward(self, x):
        return self.proj(self.act(self.emb(x)))

class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.attention = SelfAttention(cfg)
        self.att_norm = nn.LayerNorm(cfg.emb_dim)
        self.feed_forward = FeedForward(cfg)
        self.ff_norm = nn.LayerNorm(cfg.emb_dim)

    def forward(self, x):
        # Post-norm residual block: LayerNorm is applied after each residual addition
        z = self.attention(x)
        x = self.att_norm(z + x)
        z = self.feed_forward(x)
        out = self.ff_norm(z + x)
        return out

class Transformer(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.wemb = nn.Embedding(cfg.d_vocab, cfg.emb_dim)
        # Learned positional embeddings for a fixed context length of 10,
        # matching the causal mask in SelfAttention
        self.wpos = nn.Embedding(10, cfg.emb_dim)
        self.transformer = nn.Sequential(
            nn.LayerNorm(cfg.emb_dim),
            *[TransformerBlock(cfg) for _ in range(cfg.nb_layer)],
        )
        self.out = nn.Linear(cfg.emb_dim, cfg.d_vocab)

    def forward(self, x):
        # x: (B, T) integer token ids
        device = x.device
        B, T = x.shape
        pos = self.wpos(th.arange(T, dtype=th.long, device=device))
        x = self.wemb(x) + pos.unsqueeze(0)
        x = self.transformer(x)
        logits = self.out(x)
        return logits

@dataclass
class Config:
    num_heads: int = 4
    hidden_dim: int = 8
    emb_dim: int = 32
    d_vocab: int = 100
    nb_layer: int = 2

model = Transformer(Config())
# Quick shape check: a batch with one sequence of 10 tokens gives logits of shape (1, 10, 100)
print(model(th.arange(0, 10).unsqueeze(0)).shape)
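
# The block below is a rough training-step sketch that is not part of the
# original gist: it fits the model on random tokens purely to show how the
# logits would be fed to a next-token cross-entropy loss. The AdamW optimizer,
# learning rate and batch size are illustrative assumptions.
opt = th.optim.AdamW(model.parameters(), lr=3e-4)
tokens = th.randint(0, Config.d_vocab, (16, 10))  # (batch, context of length 10)
inputs, targets = tokens[:, :-1], tokens[:, 1:]  # predict each next token
logits = model(inputs)  # (16, 9, d_vocab)
loss = F.cross_entropy(logits.reshape(-1, Config.d_vocab), targets.reshape(-1))
opt.zero_grad()
loss.backward()
opt.step()
print(f"one-step training loss on random data: {loss.item():.3f}")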