@s-casci
Last active February 5, 2024 17:43
Train a GPT to write like Shakespeare
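# A sketch of how this script might be run (the filename below is hypothetical, and it
# assumes torch and fire are installed, e.g. via pip install torch fire). fire.Fire(main)
# turns every keyword argument of main() into a command-line flag, so the defaults can be
# overridden like this:
#   $ python gpt_shakespeare.py --num_iterations=5000 --batch_size=16
# or main() can be imported and called directly from Python:
#   from gpt_shakespeare import main
#   main(num_iterations=1000, batch_size=32)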
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR
from typing import Callable, List, Tuple
import subprocess
# Download "40,000 lines of Shakespeare from a variety of Shakespeare's plays"
subprocess.run(
    [
        "wget",
        "-nc",
        "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt",
    ],
    check=True,
)

class Tokenizer:
    # Character-level tokenizer: every distinct character in the corpus becomes one token id.
    def __init__(self, text: str):
        tokens = list(set(text))
        self.chars_to_tokens = {t: i for (i, t) in enumerate(tokens)}
        self.tokens_to_chars = {i: t for (i, t) in enumerate(tokens)}
        self.vocab_size = len(tokens)

    def encode(self, text: str) -> List[int]:
        return [self.chars_to_tokens[char] for char in text]

    def decode(self, tokens: List[int]) -> str:
        return "".join([self.tokens_to_chars[token] for token in tokens])

class CausalDecoderLayer(nn.Module):
    # One decoder block: two causal self-attention sub-layers followed by a feed-forward
    # sub-layer, each wrapped in a residual connection and a post-LayerNorm.
    def __init__(
        self, model_input_features: int, num_heads: int, feed_forward_dimension: int, activation: str = "gelu"
    ):
        super().__init__()
        self.self_attention1 = nn.MultiheadAttention(model_input_features, num_heads, batch_first=True)
        self.self_attention2 = nn.MultiheadAttention(model_input_features, num_heads, batch_first=True)
        self.linear1 = nn.Linear(model_input_features, feed_forward_dimension)
        self.feed_forward_dropout = nn.Dropout(0.1)
        self.linear2 = nn.Linear(feed_forward_dimension, model_input_features)
        self.norm1 = nn.LayerNorm(model_input_features, eps=1e-5)
        self.norm2 = nn.LayerNorm(model_input_features, eps=1e-5)
        self.norm3 = nn.LayerNorm(model_input_features, eps=1e-5)
        self.dropout1 = nn.Dropout(0.1)
        self.dropout2 = nn.Dropout(0.1)
        self.dropout3 = nn.Dropout(0.1)
        self.activation = self._get_activation_fn(activation)

    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        x = self.norm1(x + self._self_attention_block(x, attention_mask, self.self_attention1, self.dropout1))
        x = self.norm2(x + self._self_attention_block(x, attention_mask, self.self_attention2, self.dropout2))
        x = self.norm3(x + self._feed_forward_block(x))
        return x

    def _self_attention_block(
        self, x: torch.Tensor, attention_mask: torch.Tensor, attention_layer: nn.Module, dropout_layer: nn.Module
    ) -> torch.Tensor:
        x = attention_layer(x, x, x, attn_mask=attention_mask, is_causal=True, need_weights=False)[0]
        x = dropout_layer(x)
        return x

    def _feed_forward_block(self, x: torch.Tensor) -> torch.Tensor:
        x = self.activation(self.linear1(x))
        x = self.feed_forward_dropout(x)
        x = self.linear2(x)
        x = self.dropout3(x)
        return x

    @staticmethod
    def _get_activation_fn(activation: str) -> Callable[[torch.Tensor], torch.Tensor]:
        if activation == "relu":
            return F.relu
        elif activation == "gelu":
            return F.gelu
        else:
            raise ValueError(f"Invalid activation provided: {activation} (choose between 'relu' and 'gelu')")
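
# Illustrative standalone use of one decoder block (the sizes are hypothetical):
#   layer = CausalDecoderLayer(model_input_features=192, num_heads=4, feed_forward_dimension=256)
#   hidden = torch.randn(2, 32, 192)                                  # (batch, block, features)
#   causal_mask = torch.triu(torch.ones(32, 32, dtype=torch.bool), diagonal=1)
#   out = layer(hidden, causal_mask)                                  # same shape as hidden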

class CausalTransformer(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        block_size: int,
        num_embeddings: int,
        num_heads: int,
        num_layers: int,
    ):
        super().__init__()
        self.tokens_embeddings = nn.Embedding(vocab_size, num_embeddings)
        self.positional_embeddings = nn.Embedding(block_size, num_embeddings)
        # The embeddings are projected up to a wider width for the decoder blocks,
        # then projected back down before the vocabulary head.
        model_input_features = 3 * num_embeddings
        feed_forward_dimension = 4 * num_embeddings
        self.attention_projection = nn.Linear(num_embeddings, model_input_features)
        self.decoder_blocks = nn.ModuleList(
            [
                CausalDecoderLayer(
                    model_input_features=model_input_features,
                    num_heads=num_heads,
                    feed_forward_dimension=feed_forward_dimension,
                    activation="gelu",
                )
                for _ in range(num_layers)
            ]
        )
        self.embeddings_projections = nn.Linear(model_input_features, num_embeddings)
        self.linear_head = nn.Linear(num_embeddings, vocab_size)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(self.device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_block_size = x.shape[-1]
        tokens_embeddings = self.tokens_embeddings(x)
        positional_embeddings = self.positional_embeddings(
            torch.arange(0, batch_block_size, dtype=torch.long, device=self.device).unsqueeze(0)
        )
        x = tokens_embeddings + positional_embeddings
        x = self.attention_projection(x)
        # Boolean causal mask: True marks the future positions a token may not attend to
        # (is_causal=True in the attention layers hints that this mask is the causal one).
        attention_mask = torch.triu(
            torch.ones(batch_block_size, batch_block_size, dtype=torch.bool, device=self.device), diagonal=1
        )
        for decoder_block in self.decoder_blocks:
            x = decoder_block(x, attention_mask)
        x = self.embeddings_projections(x)
        x = self.linear_head(x)
        return x
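
# Shape sketch of a forward pass (batch/block sizes are whatever the caller provides):
#   x:                              (batch_size, block_size) integer token ids
#   token + positional embeddings:  (batch_size, block_size, num_embeddings)
#   after attention_projection:     (batch_size, block_size, 3 * num_embeddings)
#   returned logits:                (batch_size, block_size, vocab_size)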

def main(
    block_size: int = 32,
    num_embeddings: int = 64,
    num_heads: int = 4,
    num_layers: int = 4,
    learning_rate: float = 1e-3,
    num_iterations: int = 5000,
    batch_size: int = 16,
    generate_every: int = 500,
    max_new_tokens: int = 500,
):
    with open("input.txt", "r", encoding="utf-8") as f:
        text = f.read()
    tokenizer = Tokenizer(text)
    model = CausalTransformer(
        block_size=block_size,
        vocab_size=tokenizer.vocab_size,
        num_embeddings=num_embeddings,
        num_heads=num_heads,
        num_layers=num_layers,
    )
    print(sum(parameter.numel() for parameter in model.parameters()) / 1e6, "M parameters")
    data = torch.tensor(tokenizer.encode(text), dtype=torch.long)

    def get_batch(batch_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # Sample random windows of block_size characters; targets are the same windows shifted by one.
        indices = torch.randint(len(data) - block_size, (batch_size,))
        batch_xs = torch.stack([data[index : index + block_size] for index in indices]).to(model.device)
        batch_ys = torch.stack([data[index + 1 : index + block_size + 1] for index in indices]).to(model.device)
        return batch_xs, batch_ys

    @torch.no_grad()
    def generate(max_new_tokens: int) -> str:
        # Autoregressively sample one character at a time, feeding back the last block_size tokens.
        tokens = tokenizer.encode("\n")
        model.eval()
        for _ in range(max_new_tokens):
            preds = model(
                torch.tensor(
                    [tokens[-block_size:]],
                    dtype=torch.long,
                    device=model.device,
                )
            )
            logits = preds[0][-1]
            probs = F.softmax(logits, -1)
            next_token = torch.multinomial(probs, num_samples=1).item()
            tokens.append(next_token)
        model.train()
        return tokenizer.decode(tokens)

    optimizer = AdamW(model.parameters(), lr=learning_rate)
    scheduler = LinearLR(optimizer, start_factor=1.0, end_factor=0.3, total_iters=num_iterations)
    model.train()
    for it in range(1, num_iterations + 1):
        batch_xs, batch_ys = get_batch(batch_size)
        preds = model(batch_xs)
        loss = F.cross_entropy(
            preds.view(-1, tokenizer.vocab_size),
            batch_ys.view(-1),
        )
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        print(f"it: {it:4d}, loss: {loss.item():.4f}, lr: {scheduler.get_last_lr()[0]:.6f}")
        if it % generate_every == 0:
            print(f"\n\n{generate(max_new_tokens)}\n\n")
    model.eval()
    print("\nFinal generation:")
    print(generate(max_new_tokens * 10))


if __name__ == "__main__":
    import fire
    fire.Fire(main)