Building a mini GPT-like Language Model from Scratch
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
import pprint

# Function to obtain training data, vocab, and mappings from word to index and vice versa
def get_data_and_vocab():
    # Define training data
    training_data = {
        "how are you": "i am fine <end>",
        "who is john": "a nice person <end>",
        "who is nice": "john <end>",
        "where is john": "at home <end>",
        "how is john": "i dont know <end>",
        "who are you": "mini gpt model <end>"
    }

    # Extract input and target phrases
    data_words = [k for k, _ in training_data.items()]
    target_words = [v for _, v in training_data.items()]

    # Build vocabulary from training data
    vocabulary_words = list(set(
        word.lower() for phrase in data_words + target_words for word in phrase.split(" ")
    ))

    # Ensure the <end> token is at the end of the vocabulary list and a blank is at the beginning
    vocabulary_words.remove("<end>")
    vocabulary_words.append("<end>")
    vocabulary_words.insert(0, "")

    # Create mappings from word to index and index to word
    word_to_ix = {vocabulary_words[k].lower(): k for k in range(len(vocabulary_words))}
    ix_to_word = {v: k for k, v in word_to_ix.items()}

    # Return all the necessary data and mappings
    return training_data, data_words, target_words, vocabulary_words, word_to_ix, ix_to_word
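
# A note on the resulting vocabulary layout (the middle indices depend on
# Python's set ordering and vary from run to run, but the endpoints are fixed):
#   word_to_ix[""] == 0                              -> index 0 doubles as padding
#   word_to_ix["<end>"] == len(vocabulary_words) - 1 -> <end> is always last
# Because padding uses index 0, padded positions decode back to the empty string.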
# Function to convert a batch of sequences of words to a tensor of indices
def words_to_tensor(seq_batch, device=None):
    index_batch = []
    # Loop over sequences in the batch
    for seq in seq_batch:
        word_list = seq.lower().split(" ")
        indices = [word_to_ix[word] for word in word_list if word in word_to_ix]
        t = torch.tensor(indices)
        if device is not None:
            t = t.to(device)  # Transfer tensor to the specified device
        index_batch.append(t)
    # Pad tensors to have the same length
    return pad_tensors(index_batch)
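
# words_to_tensor relies on the module-level word_to_ix mapping created in
# __main__. For example, with the training set above,
#   words_to_tensor(["how are you", "who is john"])
# returns a LongTensor of shape (2, 3): one row of vocabulary indices per phrase.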
# Function to convert a tensor of indices to a list of sequences of words
def tensor_to_words(tensor):
    index_batch = tensor.cpu().numpy().tolist()
    res = []
    for indices in index_batch:
        words = []
        for ix in indices:
            words.append(ix_to_word[ix].lower())  # Convert index to word
            if ix == word_to_ix["<end>"]:
                break  # Stop when <end> token is encountered
        res.append(" ".join(words))
    return res
# Function to pad a list of tensors to the same length
def pad_tensors(list_of_tensors):
    max_dim = max(t.shape[0] for t in list_of_tensors)  # Find the maximum length
    res = []
    for t in list_of_tensors:
        # Create a zero tensor of the desired shape
        res_t = torch.zeros(max_dim, *t.shape[1:]).type(t.dtype).to(t.device)
        res_t[:t.shape[0]] = t  # Copy the original tensor into the padded tensor
        res.append(res_t)
    # Concatenate the padded tensors and reshape so the batch dimension comes first
    res = torch.cat(res)
    return res.reshape(len(list_of_tensors), max_dim, *res.shape[1:])
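
# A minimal usage sketch: pad_tensors left-aligns each tensor and zero-pads the
# remainder, so
#   pad_tensors([torch.tensor([1, 2, 3]), torch.tensor([4])])
# yields tensor([[1, 2, 3], [4, 0, 0]]).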
# Define Self-Attention module
class SelfAttention(nn.Module):
    def __init__(self, embed_size, head_count):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size  # Size of word embeddings
        self.head_count = head_count  # Number of attention heads

        # Create linear layers for query, key and value projections for each head
        self.query_layers = nn.ModuleList([nn.Linear(embed_size, embed_size, bias=False) for _ in range(head_count)])
        self.key_layers = nn.ModuleList([nn.Linear(embed_size, embed_size, bias=False) for _ in range(head_count)])
        self.value_layers = nn.ModuleList([nn.Linear(embed_size, embed_size, bias=False) for _ in range(head_count)])
        self.fc_out = nn.Linear(head_count * embed_size, embed_size)  # Final linear layer to combine head outputs

    def forward(self, embeddings):
        batch_size, token_count = embeddings.shape[:2]
        qkvs = torch.zeros(self.head_count, 3, batch_size, token_count, self.embed_size).to(embeddings.device)
        # Loop over heads and compute query, key and value projections
        for i in range(self.head_count):
            qkvs[i, 0] = self.query_layers[i](embeddings)
            qkvs[i, 1] = self.key_layers[i](embeddings)
            qkvs[i, 2] = self.value_layers[i](embeddings)

        # Compute energy terms for each head, batch, and pair of tokens
        energy = torch.zeros(self.head_count, batch_size, token_count, token_count).to(embeddings.device)
        # Create a causal mask: False on and below the diagonal, True above it
        mask = torch.triu(torch.ones((token_count, token_count)), diagonal=1).bool().to(embeddings.device)
        for h in range(self.head_count):
            for b in range(batch_size):
                for i in range(token_count):
                    for j in range(token_count):
                        energy[h, b, i, j] = torch.dot(qkvs[h, 0, b, i], qkvs[h, 1, b, j])
                energy[h, b] = energy[h, b].masked_fill(mask, float('-inf'))  # Apply the mask

        # Compute attention scores
        attention = torch.nn.functional.softmax(energy, dim=3)

        # Compute weighted sum of values for each head and token
        out = torch.zeros(batch_size, token_count, self.head_count, self.embed_size).to(embeddings.device)
        for h in range(self.head_count):
            for b in range(batch_size):
                for i in range(token_count):
                    for j in range(token_count):
                        out[b, i, h] += attention[h, b, i, j] * qkvs[h, 2, b, j]

        # Reshape and pass through the final linear layer
        out = out.reshape(batch_size, token_count, self.head_count * self.embed_size)
        return self.fc_out(out)
    def masked_attention(self, energy):
        # Unused helper kept for reference; forward() applies the same masking inline.
        # Assumes energy has shape (batch_size, token_count, token_count).
        batch_size, token_count, _ = energy.size()
        # Create a mask that is True above the diagonal and False on and below it
        mask = torch.triu(torch.ones((token_count, token_count), dtype=torch.bool), diagonal=1)
        mask = mask.unsqueeze(0).expand(batch_size, -1, -1)  # Broadcast the mask over the batch
        # Apply the mask to the scores
        return energy.masked_fill(mask.to(energy.device), float('-inf'))
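
# A minimal vectorized sketch (an alternative formulation, not what the class
# above uses): the four nested loops compute the same result as batched matrix
# products followed by a masked softmax. Assuming per-head tensors q, k, v of
# shape (batch_size, token_count, embed_size):
#   energy = q @ k.transpose(-2, -1)                  # (B, T, T) dot products
#   energy = energy.masked_fill(mask, float('-inf'))  # causal mask
#   out = torch.softmax(energy, dim=-1) @ v           # (B, T, E) weighted sums
# Note that standard scaled dot-product attention also divides energy by
# sqrt(embed_size); this implementation omits that scaling factor.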
# Define Transformer block module
class TransformerBlock(nn.Module):
    def __init__(self, embed_size, head_count):
        super(TransformerBlock, self).__init__()
        self.attention = SelfAttention(embed_size, head_count)  # Self-attention layer
        self.norm1 = nn.LayerNorm(embed_size)  # Layer normalization
        self.norm2 = nn.LayerNorm(embed_size)  # Layer normalization
        # Feed-forward neural network
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, embed_size),
            nn.ReLU(),
            nn.Linear(embed_size, embed_size)
        )

    def forward(self, embeddings):
        attention = self.attention(embeddings)
        # Apply residual connections and layer normalization
        out = self.norm1(attention + embeddings)
        out = attention + self.feed_forward(out)
        out = self.norm2(out)
        return out
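
# Note: the second residual connection above adds `attention` rather than the
# normalized output, so this block deviates slightly from the canonical
# post-norm Transformer block (out = norm2(out + feed_forward(out))). The
# variant is kept as written; it should still train on this toy task.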
# Define Transformer module
class Transformer(nn.Module):
    def __init__(self, vocab_size, embed_size, num_layers, head_count):
        super(Transformer, self).__init__()
        self.embed_size = embed_size  # Size of word embeddings
        self.vocab_size = vocab_size  # Size of vocabulary
        self.word_embedding = nn.Embedding(vocab_size, embed_size)  # Embedding layer
        # List of transformer blocks
        self.layers = nn.ModuleList(
            [TransformerBlock(embed_size, head_count) for _ in range(num_layers)]
        )
        self.fc_out = nn.Linear(embed_size, vocab_size)  # Final linear layer to produce logits

    def forward(self, input_tokens, mask=None):
        batch_size, token_count = input_tokens.shape[:2]
        out = self.word_embedding(input_tokens)  # Obtain word embeddings
        # Compute position encodings and add to word embeddings
        positions = torch.arange(0, token_count).expand(batch_size, token_count).to(input_tokens.device)
        position_encoding = self.position_encoding(positions, self.embed_size)
        out += position_encoding.reshape(out.shape)
        # Pass through each transformer block
        for layer in self.layers:
            out = layer(out)
        # Produce logits for the final token in each sequence. Raw logits are
        # returned (rather than softmax probabilities) because
        # nn.CrossEntropyLoss applies log-softmax internally, so applying a
        # softmax here as well would effectively do it twice during training
        return self.fc_out(out[:, -1, :].reshape(batch_size, self.embed_size)).reshape(batch_size, self.vocab_size)

    def position_encoding(self, positions, embed_size):
        # Compute position encoding for each position and dimension
        angle_rads = self.get_angles(
            positions.unsqueeze(2).float(),
            torch.arange(embed_size)[None, None, :].float().to(positions.device),
            embed_size
        )
        sines = torch.sin(angle_rads[:, :, 0::2])  # Compute sine of angle for even dimensions
        cosines = torch.cos(angle_rads[:, :, 1::2])  # Compute cosine of angle for odd dimensions
        pos_encoding = torch.cat([sines, cosines], dim=-1)  # Concatenate sine and cosine values
        pos_encoding = pos_encoding[None, ...]
        return pos_encoding

    def get_angles(self, pos, i, embed_size):
        # Compute angle rate for each position and dimension
        angle_rates = 1 / torch.pow(10000, (2 * (i // 2)) / embed_size)
        return pos * angle_rates
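
# For reference, the standard sinusoidal position encoding from
# "Attention Is All You Need" is:
#   PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
#   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
# The implementation above computes the same angles but concatenates all sines
# followed by all cosines along the embedding dimension instead of interleaving
# them; in principle this fixed permutation of dimensions does not change what
# the model can learn.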
# Function to train the model recursively over each sequence and token
def train_recursive(model, data, targets, optimizer, criterion):
    model.train()  # Set model to training mode
    optimizer.zero_grad()  # Zero the gradients
    total_loss = 0  # Initialize total loss
    batch_size, token_count, token_count_out = data.shape[0], data.shape[1], targets.shape[1]

    # Loop over sequences in the batch
    for b in range(batch_size):
        end_encountered = False
        cur_count = 0
        # Loop over tokens in the sequence
        while not end_encountered:
            target_vector = torch.zeros(model.vocab_size).to(data.device)  # Initialize target vector
            if cur_count != token_count_out:
                expected_next_token_idx = targets[b, cur_count]  # Get index of expected next token
                target_vector[expected_next_token_idx] = 1  # Set the corresponding element of the target vector to 1

            # Concatenate current input and output tokens and pass through the model
            if cur_count > 0:
                model_input = data[b].reshape(token_count).to(data.device)
                part_of_output = targets[b, :cur_count].to(data.device)
                model_input = torch.cat((model_input, part_of_output))
            else:
                model_input = data[b]
            out = model(model_input.reshape(1, token_count + cur_count))

            # Compute loss and accumulate total loss
            loss = criterion(out, target_vector.reshape(out.shape))
            total_loss += loss
            cur_count += 1

            # Stop when the end of the sequence is reached
            if cur_count > token_count_out:
                end_encountered = True

    # Backpropagate gradients and update model parameters
    total_loss.backward()
    optimizer.step()
    return total_loss.item() / batch_size
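
# Training uses teacher forcing: at each step the model sees the input phrase
# plus the ground-truth target prefix and must predict the next target token.
# For the pair "how are you" -> "i am fine <end>", the steps are:
#   "how are you"           -> predict "i"
#   "how are you i"         -> predict "am"
#   "how are you i am"      -> predict "fine"
#   "how are you i am fine" -> predict "<end>"
# The per-step losses are summed and backpropagated once per batch.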
# Function to perform inference recursively for each sequence in a batch
def infer_recursive(model, input_vectors, max_output_token_count=10):
    model.eval()  # Set model to evaluation mode
    outputs = []

    # Loop over sequences in the batch
    for i in range(input_vectors.shape[0]):
        print(f"Inferring sequence {i}")
        input_vector = input_vectors[i].reshape(1, input_vectors.shape[1])
        predicted_sequence = []
        wc = 0  # Initialize word count
        with torch.no_grad():  # Disable gradient computation
            while True:
                output = model(input_vector)  # Pass current input through the model
                predicted_index = output[0, :].argmax().item()  # Get index of predicted token
                predicted_sequence.append(predicted_index)  # Append predicted index to the sequence
                # Stop when the <end> token is predicted or the maximum output length is reached
                if predicted_index == word_to_ix['<end>'] or wc > max_output_token_count:
                    break
                # Append the predicted token to the input and increment the word count
                input_vector = torch.cat(
                    [input_vector, torch.tensor([[predicted_index]], device=input_vector.device)], dim=1
                )
                wc += 1
        outputs.append(torch.tensor(predicted_sequence))  # Append predicted sequence to outputs

    outputs = pad_tensors(outputs)  # Pad predicted sequences to the same length
    return outputs
# Function to demonstrate training and inference
def example_training_and_inference():
    # Get model hyperparameters from vocabulary size
    vocab_size = len(word_to_ix)
    embed_size = 512
    num_layers = 4
    heads = 3

    # Create model, optimizer, and loss function
    device = torch.device("cpu")
    model = Transformer(vocab_size, embed_size, num_layers, heads).to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.00001)
    # CrossEntropyLoss receives one-hot probability targets in train_recursive,
    # which requires PyTorch >= 1.10
    criterion = nn.CrossEntropyLoss()

    # Convert training data to tensors
    data = words_to_tensor(data_words, device=device)
    targets = words_to_tensor(target_words, device=device)

    # Train the model for 55 epochs
    for epoch in range(55):
        avg_loss = train_recursive(model, data, targets, optimizer, criterion)
        print(f'Epoch {epoch + 1}, Loss: {avg_loss:.4f}')

    # Perform inference on the training data
    input_vector = words_to_tensor(data_words, device=device)
    predicted_vector = infer_recursive(model, input_vector)
    predicted_words = tensor_to_words(predicted_vector)

    # Print training data and model output
    print("\n\n\n")
    print("Training Data:")
    pprint.pprint(training_data)
    print("\n\n")
    print("Model Inference:")
    result_data = {data_words[k]: predicted_words[k] for k in range(len(predicted_words))}
    pprint.pprint(result_data)
# Main entry point: build the data and run the demonstration
if __name__ == "__main__":
    # Get training data and vocabulary; these become the module-level globals
    # used by words_to_tensor, tensor_to_words, and infer_recursive
    training_data, data_words, target_words, vocabulary_words, word_to_ix, ix_to_word = get_data_and_vocab()
    # Run the example training and inference function
    example_training_and_inference()
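
# To run this script (assuming PyTorch is installed): python mini_gpt.py,
# where mini_gpt.py is a hypothetical filename for this file. Training prints
# the loss for each of the 55 epochs; with only six training pairs, the aim is
# simply for the model to memorize the mapping and reproduce the target
# phrases during inference.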