Created
June 14, 2024 14:45
-
-
Save nathanmargaglio/2b36bad77b455915a52c1cec8265e67e to your computer and use it in GitHub Desktop.
Simple Transformer Demo
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import math\n", | |
"from pprint import pprint\n", | |
"\n", | |
"import torch\n", | |
"import torch.nn as nn\n", | |
"import torch.optim as optim\n", | |
"from torch.utils.data import DataLoader, Dataset" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[['the', 'cat', 'sat', 'on', 'the', 'mat', '<|end|>'],\n", | |
" ['the', 'dog', 'barked', 'at', 'the', 'mailman', '<|end|>'],\n", | |
" ['a', 'bird', 'flew', 'over', 'the', 'fence', '<|end|>'],\n", | |
" ['the', 'dog', 'sat', 'by', 'the', 'man', '<|end|>']]\n", | |
"{'<|end|>': 0, 'the': 1, 'cat': 2, 'sat': 3, 'on': 4, 'mat': 5, 'dog': 6, 'barked': 7, 'at': 8, 'mailman': 9, 'a': 10, 'bird': 11, 'flew': 12, 'over': 13, 'fence': 14, 'by': 15, 'man': 16}\n" | |
] | |
} | |
], | |
"source": [ | |
"# Define a simple dataset with subsequences\n", | |
"class SimpleDataset(Dataset):\n", | |
" def __init__(self, data, vocab):\n", | |
" self.data = []\n", | |
" self.targets = []\n", | |
" for sentence in data:\n", | |
" tokenized_sentence = [vocab[word] for word in sentence]\n", | |
" for i in range(1, len(tokenized_sentence)):\n", | |
" self.data.append(tokenized_sentence[:i])\n", | |
" self.targets.append(tokenized_sentence[i])\n", | |
" \n", | |
" def __len__(self):\n", | |
" return len(self.data)\n", | |
" \n", | |
" def __getitem__(self, idx):\n", | |
" return torch.tensor(self.data[idx], dtype=torch.long), torch.tensor(self.targets[idx], dtype=torch.long)\n", | |
"\n", | |
"examples = \"\"\"\n", | |
"The cat sat on the mat.\n", | |
"The dog barked at the mailman.\n", | |
"A bird flew over the fence.\n", | |
"The dog sat by the man.\n", | |
"\"\"\"\n", | |
"\n", | |
"data = [sentence.split() + [\"<|end|>\"] for sentence in examples.lower().replace(\".\", \"\").replace(\",\", \"\").strip().split(\"\\n\")]\n", | |
"pprint(data)\n", | |
"\n", | |
"vocab = {\n", | |
" \"<|end|>\" : 0\n", | |
"}\n", | |
"for sentence in data:\n", | |
" for word in sentence:\n", | |
" if word not in vocab:\n", | |
" vocab[word] = len(vocab)\n", | |
"\n", | |
"# Vocabulary and tokenization\n", | |
"print(vocab)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Define a simple transformer model using PyTorch modules\n", | |
"class SimpleModuleTransformer(nn.Module):\n", | |
" def __init__(self, vocab_size, d_model, nhead, num_encoder_layers, dim_feedforward):\n", | |
" super(SimpleModuleTransformer, self).__init__()\n", | |
" self.embedding = nn.Embedding(vocab_size, d_model)\n", | |
" self.transformer = nn.Transformer(d_model, nhead, num_encoder_layers, num_encoder_layers, dim_feedforward)\n", | |
" self.fc = nn.Linear(d_model, vocab_size)\n", | |
" \n", | |
" def forward(self, src):\n", | |
" src = self.embedding(src)\n", | |
" src = src.permute(1, 0, 2)\n", | |
" output = self.transformer(src, src)\n", | |
" output = self.fc(output[-1])\n", | |
" return output" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Define a simple transformer model using custom layers\n", | |
"\n", | |
"# Custom embedding layer\n", | |
"class CustomEmbedding(nn.Module):\n", | |
" def __init__(self, vocab_size, d_model):\n", | |
" super(CustomEmbedding, self).__init__()\n", | |
" self.embedding = nn.Parameter(torch.randn(vocab_size, d_model))\n", | |
" \n", | |
" def forward(self, x):\n", | |
" return self.embedding[x]\n", | |
"\n", | |
"# Custom multi-head self-attention layer\n", | |
"class MultiHeadSelfAttention(nn.Module):\n", | |
" def __init__(self, d_model, nhead):\n", | |
" super(MultiHeadSelfAttention, self).__init__()\n", | |
" assert d_model % nhead == 0\n", | |
" self.d_k = d_model // nhead\n", | |
" self.nhead = nhead\n", | |
" self.q_linear = nn.Linear(d_model, d_model)\n", | |
" self.k_linear = nn.Linear(d_model, d_model)\n", | |
" self.v_linear = nn.Linear(d_model, d_model)\n", | |
" self.out_linear = nn.Linear(d_model, d_model)\n", | |
" \n", | |
" def forward(self, x):\n", | |
" batch_size, seq_len, d_model = x.size()\n", | |
" \n", | |
" # Linear projections\n", | |
" q = self.q_linear(x)\n", | |
" k = self.k_linear(x)\n", | |
" v = self.v_linear(x)\n", | |
" \n", | |
" # Split into multiple heads\n", | |
" q = q.view(batch_size, seq_len, self.nhead, self.d_k).transpose(1, 2)\n", | |
" k = k.view(batch_size, seq_len, self.nhead, self.d_k).transpose(1, 2)\n", | |
" v = v.view(batch_size, seq_len, self.nhead, self.d_k).transpose(1, 2)\n", | |
" \n", | |
" # Scaled dot-product attention\n", | |
" scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)\n", | |
" attn = torch.nn.functional.softmax(scores, dim=-1)\n", | |
" context = torch.matmul(attn, v)\n", | |
" \n", | |
" # Concatenate heads\n", | |
" context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)\n", | |
" \n", | |
" # Final linear layer\n", | |
" output = self.out_linear(context)\n", | |
" return output\n", | |
"\n", | |
"# Custom feed-forward neural network layer\n", | |
"class FeedForwardNN(nn.Module):\n", | |
" def __init__(self, d_model, dim_feedforward):\n", | |
" super(FeedForwardNN, self).__init__()\n", | |
" self.fc1 = nn.Linear(d_model, dim_feedforward)\n", | |
" self.fc2 = nn.Linear(dim_feedforward, d_model)\n", | |
" \n", | |
" def forward(self, x):\n", | |
" x = torch.nn.functional.relu(self.fc1(x))\n", | |
" x = self.fc2(x)\n", | |
" return x\n", | |
"\n", | |
"# Custom transformer layer\n", | |
"class TransformerLayer(nn.Module):\n", | |
" def __init__(self, d_model, nhead, dim_feedforward):\n", | |
" super(TransformerLayer, self).__init__()\n", | |
" self.self_attn = MultiHeadSelfAttention(d_model, nhead)\n", | |
" self.ffn = FeedForwardNN(d_model, dim_feedforward)\n", | |
" self.norm1 = nn.LayerNorm(d_model)\n", | |
" self.norm2 = nn.LayerNorm(d_model)\n", | |
" \n", | |
" def forward(self, x):\n", | |
" x = self.norm1(x + self.self_attn(x))\n", | |
" x = self.norm2(x + self.ffn(x))\n", | |
" return x\n", | |
"\n", | |
"# Custom transformer model\n", | |
"class SimpleTransformer(nn.Module):\n", | |
" def __init__(self, vocab_size, d_model, nhead, num_encoder_layers, dim_feedforward):\n", | |
" super(SimpleTransformer, self).__init__()\n", | |
" self.embedding = CustomEmbedding(vocab_size, d_model)\n", | |
" self.layers = nn.ModuleList([TransformerLayer(d_model, nhead, dim_feedforward) for _ in range(num_encoder_layers)])\n", | |
" self.fc = nn.Linear(d_model, vocab_size)\n", | |
" \n", | |
" def forward(self, src):\n", | |
" src = self.embedding(src)\n", | |
" for layer in self.layers:\n", | |
" src = layer(src)\n", | |
" output = self.fc(src[:, -1, :])\n", | |
" return output" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Padding function for collate_fn\n", | |
"def collate_fn(batch):\n", | |
" data, targets = zip(*batch)\n", | |
" data_lengths = [len(seq) for seq in data]\n", | |
" max_length = max(data_lengths)\n", | |
" \n", | |
" padded_data = torch.zeros(len(data), max_length, dtype=torch.long)\n", | |
" for i, seq in enumerate(data):\n", | |
" padded_data[i, :len(seq)] = seq\n", | |
"\n", | |
" targets = torch.stack(targets)\n", | |
" return padded_data, targets\n", | |
"\n", | |
"# Hyperparameters\n", | |
"vocab_size = len(vocab)\n", | |
"d_model = 16\n", | |
"nhead = 2\n", | |
"num_encoder_layers = 2\n", | |
"dim_feedforward = 32\n", | |
"\n", | |
"# Create dataset and dataloader\n", | |
"dataset = SimpleDataset(data, vocab)\n", | |
"dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)\n", | |
"\n", | |
"# Initialize model, loss function, and optimizer\n", | |
"use_module = False\n", | |
"if use_module:\n", | |
" model = SimpleModuleTransformer(vocab_size, d_model, nhead, num_encoder_layers, dim_feedforward)\n", | |
"else:\n", | |
" model = SimpleTransformer(vocab_size, d_model, nhead, num_encoder_layers, dim_feedforward)\n", | |
"criterion = nn.CrossEntropyLoss()\n", | |
"optimizer = optim.Adam(model.parameters(), lr=0.001)\n", | |
"\n", | |
"# Training loop\n", | |
"for epoch in range(100):\n", | |
" for src, tgt in dataloader:\n", | |
" optimizer.zero_grad()\n", | |
" output = model(src)\n", | |
" loss = criterion(output, tgt)\n", | |
" loss.backward()\n", | |
" optimizer.step()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 30, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"barked\t\t: 3.6009\n", | |
"sat\t\t: 3.5092\n", | |
"by\t\t: 0.9870\n", | |
"man\t\t: 0.2676\n", | |
"on\t\t: 0.2619\n", | |
"at\t\t: 0.0934\n", | |
"<|end|>\t\t: 0.0068\n", | |
"bird\t\t: -0.0402\n", | |
"a\t\t: -0.7484\n", | |
"the\t\t: -0.7982\n", | |
"Predicted next word: barked\n" | |
] | |
} | |
], | |
"source": [ | |
"# Test the model by predicting the next word\n", | |
"test_sentence = \"the dog\"\n", | |
"test_tokenized = torch.tensor([[vocab[word] for word in test_sentence.split()]], dtype=torch.long)\n", | |
"with torch.no_grad():\n", | |
" output = model(test_tokenized)\n", | |
" \n", | |
" # print top 10 words with highest probability\n", | |
" sorted_output = torch.argsort(output, descending=True)[0][:10]\n", | |
" for idx in sorted_output:\n", | |
" word = [word for word, i in vocab.items() if i == idx.item()][0]\n", | |
" buffer = len(word) < 8 and \"\\t\\t\" or \"\\t\"\n", | |
" print(f\"{word}{buffer}: {output[0, idx].item():.4f}\")\n", | |
" \n", | |
" predicted_word = torch.argmax(output, dim=1).item()\n", | |
" for word, idx in vocab.items():\n", | |
" if idx == predicted_word:\n", | |
" print(f\"Predicted next word: {word}\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 32, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Generated sentence: the dog sat by the man\n" | |
] | |
} | |
], | |
"source": [ | |
"# Test the model by generating a sentence\n", | |
"seed_sentence = \"the dog sat\"\n", | |
"generated_sentence = seed_sentence.split()[:]\n", | |
"with torch.no_grad():\n", | |
" for _ in range(100):\n", | |
" test_tokenized = torch.tensor([[vocab[word] for word in generated_sentence]], dtype=torch.long)\n", | |
" output = model(test_tokenized)\n", | |
" predicted_word = torch.argmax(output, dim=1).item()\n", | |
" for word, idx in vocab.items():\n", | |
" if idx == predicted_word:\n", | |
" generated_sentence.append(word)\n", | |
" break\n", | |
" if word == \"<|end|>\":\n", | |
" generated_sentence.pop()\n", | |
" break\n", | |
"\n", | |
"print(\"Generated sentence:\", \" \".join(generated_sentence))" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "venv", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.11.9" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment