Simple Transformer Demo
@nathanmargaglio · Created June 14, 2024 14:45
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"from pprint import pprint\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"from torch.utils.data import DataLoader, Dataset"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[['the', 'cat', 'sat', 'on', 'the', 'mat', '<|end|>'],\n",
" ['the', 'dog', 'barked', 'at', 'the', 'mailman', '<|end|>'],\n",
" ['a', 'bird', 'flew', 'over', 'the', 'fence', '<|end|>'],\n",
" ['the', 'dog', 'sat', 'by', 'the', 'man', '<|end|>']]\n",
"{'<|end|>': 0, 'the': 1, 'cat': 2, 'sat': 3, 'on': 4, 'mat': 5, 'dog': 6, 'barked': 7, 'at': 8, 'mailman': 9, 'a': 10, 'bird': 11, 'flew': 12, 'over': 13, 'fence': 14, 'by': 15, 'man': 16}\n"
]
}
],
"source": [
"# Define a simple dataset with subsequences\n",
"class SimpleDataset(Dataset):\n",
" def __init__(self, data, vocab):\n",
" self.data = []\n",
" self.targets = []\n",
" for sentence in data:\n",
" tokenized_sentence = [vocab[word] for word in sentence]\n",
" for i in range(1, len(tokenized_sentence)):\n",
" self.data.append(tokenized_sentence[:i])\n",
" self.targets.append(tokenized_sentence[i])\n",
" \n",
" def __len__(self):\n",
" return len(self.data)\n",
" \n",
" def __getitem__(self, idx):\n",
" return torch.tensor(self.data[idx], dtype=torch.long), torch.tensor(self.targets[idx], dtype=torch.long)\n",
"\n",
"examples = \"\"\"\n",
"The cat sat on the mat.\n",
"The dog barked at the mailman.\n",
"A bird flew over the fence.\n",
"The dog sat by the man.\n",
"\"\"\"\n",
"\n",
"data = [sentence.split() + [\"<|end|>\"] for sentence in examples.lower().replace(\".\", \"\").replace(\",\", \"\").strip().split(\"\\n\")]\n",
"pprint(data)\n",
"\n",
"vocab = {\n",
" \"<|end|>\" : 0\n",
"}\n",
"for sentence in data:\n",
" for word in sentence:\n",
" if word not in vocab:\n",
" vocab[word] = len(vocab)\n",
"\n",
"# Vocabulary and tokenization\n",
"print(vocab)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"# Define a simple transformer model using PyTorch modules\n",
"class SimpleModuleTransformer(nn.Module):\n",
" def __init__(self, vocab_size, d_model, nhead, num_encoder_layers, dim_feedforward):\n",
" super(SimpleModuleTransformer, self).__init__()\n",
" self.embedding = nn.Embedding(vocab_size, d_model)\n",
" self.transformer = nn.Transformer(d_model, nhead, num_encoder_layers, num_encoder_layers, dim_feedforward)\n",
" self.fc = nn.Linear(d_model, vocab_size)\n",
" \n",
" def forward(self, src):\n",
" src = self.embedding(src)\n",
" src = src.permute(1, 0, 2)\n",
" output = self.transformer(src, src)\n",
" output = self.fc(output[-1])\n",
" return output"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"# Define a simple transformer model using custom layers\n",
"\n",
"# Custom embedding layer\n",
"class CustomEmbedding(nn.Module):\n",
" def __init__(self, vocab_size, d_model):\n",
" super(CustomEmbedding, self).__init__()\n",
" self.embedding = nn.Parameter(torch.randn(vocab_size, d_model))\n",
" \n",
" def forward(self, x):\n",
" return self.embedding[x]\n",
"\n",
"# Custom multi-head self-attention layer\n",
"class MultiHeadSelfAttention(nn.Module):\n",
" def __init__(self, d_model, nhead):\n",
" super(MultiHeadSelfAttention, self).__init__()\n",
" assert d_model % nhead == 0\n",
" self.d_k = d_model // nhead\n",
" self.nhead = nhead\n",
" self.q_linear = nn.Linear(d_model, d_model)\n",
" self.k_linear = nn.Linear(d_model, d_model)\n",
" self.v_linear = nn.Linear(d_model, d_model)\n",
" self.out_linear = nn.Linear(d_model, d_model)\n",
" \n",
" def forward(self, x):\n",
" batch_size, seq_len, d_model = x.size()\n",
" \n",
" # Linear projections\n",
" q = self.q_linear(x)\n",
" k = self.k_linear(x)\n",
" v = self.v_linear(x)\n",
" \n",
" # Split into multiple heads\n",
" q = q.view(batch_size, seq_len, self.nhead, self.d_k).transpose(1, 2)\n",
" k = k.view(batch_size, seq_len, self.nhead, self.d_k).transpose(1, 2)\n",
" v = v.view(batch_size, seq_len, self.nhead, self.d_k).transpose(1, 2)\n",
" \n",
" # Scaled dot-product attention\n",
" scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)\n",
" attn = torch.nn.functional.softmax(scores, dim=-1)\n",
" context = torch.matmul(attn, v)\n",
" \n",
" # Concatenate heads\n",
" context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)\n",
" \n",
" # Final linear layer\n",
" output = self.out_linear(context)\n",
" return output\n",
"\n",
"# Custom feed-forward neural network layer\n",
"class FeedForwardNN(nn.Module):\n",
" def __init__(self, d_model, dim_feedforward):\n",
" super(FeedForwardNN, self).__init__()\n",
" self.fc1 = nn.Linear(d_model, dim_feedforward)\n",
" self.fc2 = nn.Linear(dim_feedforward, d_model)\n",
" \n",
" def forward(self, x):\n",
" x = torch.nn.functional.relu(self.fc1(x))\n",
" x = self.fc2(x)\n",
" return x\n",
"\n",
"# Custom transformer layer\n",
"class TransformerLayer(nn.Module):\n",
" def __init__(self, d_model, nhead, dim_feedforward):\n",
" super(TransformerLayer, self).__init__()\n",
" self.self_attn = MultiHeadSelfAttention(d_model, nhead)\n",
" self.ffn = FeedForwardNN(d_model, dim_feedforward)\n",
" self.norm1 = nn.LayerNorm(d_model)\n",
" self.norm2 = nn.LayerNorm(d_model)\n",
" \n",
" def forward(self, x):\n",
" x = self.norm1(x + self.self_attn(x))\n",
" x = self.norm2(x + self.ffn(x))\n",
" return x\n",
"\n",
"# Custom transformer model\n",
"class SimpleTransformer(nn.Module):\n",
" def __init__(self, vocab_size, d_model, nhead, num_encoder_layers, dim_feedforward):\n",
" super(SimpleTransformer, self).__init__()\n",
" self.embedding = CustomEmbedding(vocab_size, d_model)\n",
" self.layers = nn.ModuleList([TransformerLayer(d_model, nhead, dim_feedforward) for _ in range(num_encoder_layers)])\n",
" self.fc = nn.Linear(d_model, vocab_size)\n",
" \n",
" def forward(self, src):\n",
" src = self.embedding(src)\n",
" for layer in self.layers:\n",
" src = layer(src)\n",
" output = self.fc(src[:, -1, :])\n",
" return output"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"# Padding function for collate_fn\n",
"def collate_fn(batch):\n",
" data, targets = zip(*batch)\n",
" data_lengths = [len(seq) for seq in data]\n",
" max_length = max(data_lengths)\n",
" \n",
" padded_data = torch.zeros(len(data), max_length, dtype=torch.long)\n",
" for i, seq in enumerate(data):\n",
" padded_data[i, :len(seq)] = seq\n",
"\n",
" targets = torch.stack(targets)\n",
" return padded_data, targets\n",
"\n",
"# Hyperparameters\n",
"vocab_size = len(vocab)\n",
"d_model = 16\n",
"nhead = 2\n",
"num_encoder_layers = 2\n",
"dim_feedforward = 32\n",
"\n",
"# Create dataset and dataloader\n",
"dataset = SimpleDataset(data, vocab)\n",
"dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)\n",
"\n",
"# Initialize model, loss function, and optimizer\n",
"use_module = False\n",
"if use_module:\n",
" model = SimpleModuleTransformer(vocab_size, d_model, nhead, num_encoder_layers, dim_feedforward)\n",
"else:\n",
" model = SimpleTransformer(vocab_size, d_model, nhead, num_encoder_layers, dim_feedforward)\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
"\n",
"# Training loop\n",
"for epoch in range(100):\n",
" for src, tgt in dataloader:\n",
" optimizer.zero_grad()\n",
" output = model(src)\n",
" loss = criterion(output, tgt)\n",
" loss.backward()\n",
" optimizer.step()"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"barked\t\t: 3.6009\n",
"sat\t\t: 3.5092\n",
"by\t\t: 0.9870\n",
"man\t\t: 0.2676\n",
"on\t\t: 0.2619\n",
"at\t\t: 0.0934\n",
"<|end|>\t\t: 0.0068\n",
"bird\t\t: -0.0402\n",
"a\t\t: -0.7484\n",
"the\t\t: -0.7982\n",
"Predicted next word: barked\n"
]
}
],
"source": [
"# Test the model by predicting the next word\n",
"test_sentence = \"the dog\"\n",
"test_tokenized = torch.tensor([[vocab[word] for word in test_sentence.split()]], dtype=torch.long)\n",
"with torch.no_grad():\n",
" output = model(test_tokenized)\n",
" \n",
" # print top 10 words with highest probability\n",
" sorted_output = torch.argsort(output, descending=True)[0][:10]\n",
" for idx in sorted_output:\n",
" word = [word for word, i in vocab.items() if i == idx.item()][0]\n",
" buffer = len(word) < 8 and \"\\t\\t\" or \"\\t\"\n",
" print(f\"{word}{buffer}: {output[0, idx].item():.4f}\")\n",
" \n",
" predicted_word = torch.argmax(output, dim=1).item()\n",
" for word, idx in vocab.items():\n",
" if idx == predicted_word:\n",
" print(f\"Predicted next word: {word}\")"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Generated sentence: the dog sat by the man\n"
]
}
],
"source": [
"# Test the model by generating a sentence\n",
"seed_sentence = \"the dog sat\"\n",
"generated_sentence = seed_sentence.split()[:]\n",
"with torch.no_grad():\n",
" for _ in range(100):\n",
" test_tokenized = torch.tensor([[vocab[word] for word in generated_sentence]], dtype=torch.long)\n",
" output = model(test_tokenized)\n",
" predicted_word = torch.argmax(output, dim=1).item()\n",
" for word, idx in vocab.items():\n",
" if idx == predicted_word:\n",
" generated_sentence.append(word)\n",
" break\n",
" if word == \"<|end|>\":\n",
" generated_sentence.pop()\n",
" break\n",
"\n",
"print(\"Generated sentence:\", \" \".join(generated_sentence))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}