Train a GPT to write like Shakespeare
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR
from typing import Callable, List, Tuple
import subprocess

# Download "40,000 lines of Shakespeare from a variety of Shakespeare's plays"
subprocess.run(
    [
        "wget",
        "-nc",
        "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt",
    ],
    check=True,
)
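

# Character-level tokenizer: every distinct character in the corpus becomes a
# token id, so vocab_size is simply the number of unique characters in the file.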
class Tokenizer:
    def __init__(self, text: str):
        tokens = list(set(text))
        self.chars_to_tokens = {t: i for (i, t) in enumerate(tokens)}
        self.tokens_to_chars = {i: t for (i, t) in enumerate(tokens)}
        self.vocab_size = len(tokens)

    def encode(self, text: str) -> List[int]:
        return [self.chars_to_tokens[char] for char in text]

    def decode(self, tokens: List[int]) -> str:
        return "".join([self.tokens_to_chars[token] for token in tokens])
class CausalDecoderLayer(nn.Module):
    def __init__(
        self, model_input_features: int, num_heads: int, feed_forward_dimension: int, activation: str = "gelu"
    ):
        super().__init__()
        self.self_attention1 = nn.MultiheadAttention(model_input_features, num_heads, batch_first=True)
        self.self_attention2 = nn.MultiheadAttention(model_input_features, num_heads, batch_first=True)
        self.linear1 = nn.Linear(model_input_features, feed_forward_dimension)
        self.feed_forward_dropout = nn.Dropout(0.1)
        self.linear2 = nn.Linear(feed_forward_dimension, model_input_features)
        self.norm1 = nn.LayerNorm(model_input_features, eps=1e-5)
        self.norm2 = nn.LayerNorm(model_input_features, eps=1e-5)
        self.norm3 = nn.LayerNorm(model_input_features, eps=1e-5)
        self.dropout1 = nn.Dropout(0.1)
        self.dropout2 = nn.Dropout(0.1)
        self.dropout3 = nn.Dropout(0.1)
        self.activation = self._get_activation_fn(activation)
    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        x = self.norm1(x + self._self_attention_block(x, attention_mask, self.self_attention1, self.dropout1))
        x = self.norm2(x + self._self_attention_block(x, attention_mask, self.self_attention2, self.dropout2))
        x = self.norm3(x + self._feed_forward_block(x))
        return x

    def _self_attention_block(
        self, x: torch.Tensor, attention_mask: torch.Tensor, attention_layer: nn.Module, dropout_layer: nn.Module
    ) -> torch.Tensor:
        x = attention_layer(x, x, x, attn_mask=attention_mask, is_causal=True, need_weights=False)[0]
        x = dropout_layer(x)
        return x

    def _feed_forward_block(self, x: torch.Tensor) -> torch.Tensor:
        x = self.activation(self.linear1(x))
        x = self.feed_forward_dropout(x)
        x = self.linear2(x)
        x = self.dropout3(x)
        return x
    @staticmethod
    def _get_activation_fn(activation: str) -> Callable[[torch.Tensor], torch.Tensor]:
        if activation == "relu":
            return F.relu
        elif activation == "gelu":
            return F.gelu
        else:
            raise ValueError(f"Invalid activation provided: {activation} (choose between 'relu' and 'gelu')")
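

# Decoder-only transformer: token embeddings plus learned positional embeddings
# are projected up to 3 * num_embeddings before the decoder stack, then
# projected back down to num_embeddings before the final vocabulary head.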
class CausalTransformer(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        block_size: int,
        num_embeddings: int,
        num_heads: int,
        num_layers: int,
    ):
        super().__init__()
        self.tokens_embeddings = nn.Embedding(vocab_size, num_embeddings)
        self.positional_embeddings = nn.Embedding(block_size, num_embeddings)
        model_input_features = 3 * num_embeddings
        feed_forward_dimension = 4 * num_embeddings
        self.attention_projection = nn.Linear(num_embeddings, model_input_features)
        self.decoder_blocks = nn.ModuleList(
            [
                CausalDecoderLayer(
                    model_input_features=model_input_features,
                    num_heads=num_heads,
                    feed_forward_dimension=feed_forward_dimension,
                    activation="gelu",
                )
                for _ in range(num_layers)
            ]
        )
        self.embeddings_projections = nn.Linear(model_input_features, num_embeddings)
        self.linear_head = nn.Linear(num_embeddings, vocab_size)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(self.device)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_block_size = x.shape[-1]
        tokens_embeddings = self.tokens_embeddings(x)
        positional_embeddings = self.positional_embeddings(
            torch.arange(0, batch_block_size, dtype=torch.long, device=self.device).unsqueeze(0)
        )
        x = tokens_embeddings + positional_embeddings
        x = self.attention_projection(x)
        # Boolean causal mask in the format nn.MultiheadAttention expects: True
        # above the diagonal marks the future positions a token may not attend to.
        attention_mask = torch.triu(
            torch.ones(batch_block_size, batch_block_size, dtype=torch.bool, device=self.device), diagonal=1
        )
        for decoder_block in self.decoder_blocks:
            x = decoder_block(x, attention_mask)
        x = self.embeddings_projections(x)
        x = self.linear_head(x)
        return x
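

# Given input ids of shape (batch, block), CausalTransformer returns logits of
# shape (batch, block, vocab_size): one next-character distribution per position.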
def main(
    block_size: int = 32,
    num_embeddings: int = 64,
    num_heads: int = 4,
    num_layers: int = 4,
    learning_rate: float = 1e-3,
    num_iterations: int = 5000,
    batch_size: int = 16,
    generate_every: int = 500,
    max_new_tokens: int = 500,
):
    with open("input.txt", "r", encoding="utf-8") as f:
        text = f.read()
    tokenizer = Tokenizer(text)
    model = CausalTransformer(
        block_size=block_size,
        vocab_size=tokenizer.vocab_size,
        num_embeddings=num_embeddings,
        num_heads=num_heads,
        num_layers=num_layers,
    )
    print(sum(parameter.numel() for parameter in model.parameters()) / 1e6, "M parameters")

    data = torch.tensor(tokenizer.encode(text), dtype=torch.long)
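
    # Each batch is `batch_size` random windows of `block_size` characters; the
    # targets are the same windows shifted one character to the right, so every
    # position is trained to predict the character that follows it.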
    def get_batch(batch_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
        indices = torch.randint(len(data) - block_size, (batch_size,))
        batch_xs = torch.stack([data[index : index + block_size] for index in indices]).to(model.device)
        batch_ys = torch.stack([data[index + 1 : index + block_size + 1] for index in indices]).to(model.device)
        return batch_xs, batch_ys
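
    # Autoregressive sampling: start from a newline, repeatedly feed the last
    # `block_size` tokens through the model, and sample the next character from
    # the softmax over the final position's logits.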
    @torch.no_grad()
    def generate(max_new_tokens: int) -> str:
        tokens = tokenizer.encode("\n")
        model.eval()
        for _ in range(max_new_tokens):
            preds = model(
                torch.tensor(
                    [tokens[-block_size:]],
                    dtype=torch.long,
                    device=model.device,
                )
            )
            logits = preds[0][-1]
            probs = F.softmax(logits, -1)
            next_token = torch.multinomial(probs, num_samples=1).item()
            tokens.append(next_token)
        model.train()
        return tokenizer.decode(tokens)
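
    # AdamW with the learning rate decayed linearly to 30% of its starting
    # value over the course of training.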
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    scheduler = LinearLR(optimizer, start_factor=1.0, end_factor=0.3, total_iters=num_iterations)

    model.train()
    for it in range(1, num_iterations + 1):
        batch_xs, batch_ys = get_batch(batch_size)
        preds = model(batch_xs)
        loss = F.cross_entropy(
            preds.view(-1, tokenizer.vocab_size),
            batch_ys.view(-1),
        )
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        print(f"it: {it:4d}, loss: {loss.item():.4f}, lr: {scheduler.get_last_lr()[0]:.6f}")
        if it % generate_every == 0:
            print(f"\n\n{generate(max_new_tokens)}\n\n")

    model.eval()
    print("\nFinal generation:")
    print(generate(max_new_tokens * 10))
if __name__ == "__main__":
    import fire

    fire.Fire(main)
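
Because the entry point is wrapped in fire.Fire, every keyword argument of main is exposed as a command-line flag. Assuming the gist is saved locally as train.py (the filename is just for illustration), a run might look like:

python train.py --num_iterations=2000 --batch_size=32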