@rasbt · Created January 21, 2024
Transformer PyTorch Pseudo Code
import torch.nn as nn


class Transformer(nn.Module):
    def __init__(self, ...):
        super().__init__()
        ...
        self.encoder_layers = nn.ModuleList(...)
        self.decoder_layers = nn.ModuleList(...)
        ...

    def forward(self, x, y):
        # x: source sequence, y: target sequence
        # x_mask / y_mask are assumed to be built in the elided code above
        ...
        # Encoder: refine the source representation layer by layer
        encoder_result = x
        for layer in self.encoder_layers:
            encoder_result = layer(encoder_result, x_mask)

        # Decoder: process the targets while attending to the encoder output
        decoder_result = y
        for layer in self.decoder_layers:
            decoder_result = layer(decoder_result, encoder_result, x_mask, y_mask)

        output = ...
        return output
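The individual layers are elided in the pseudocode above. As a minimal sketch of what one entry of self.encoder_layers might look like, the class below pairs multi-head self-attention with a position-wise feed-forward block, each wrapped in a residual connection and layer norm. It is illustrative only and not part of the original gist; names and sizes such as EncoderLayer, d_model, num_heads, and d_ff are assumptions.

import torch.nn as nn

class EncoderLayer(nn.Module):
    """Illustrative encoder layer: self-attention + feed-forward, each with a residual connection."""
    def __init__(self, d_model=512, num_heads=8, d_ff=2048, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, x_mask=None):
        # Self-attention, then residual connection and layer norm
        attn_out, _ = self.self_attn(x, x, x, key_padding_mask=x_mask)
        x = self.norm1(x + self.dropout(attn_out))
        # Position-wise feed-forward, then residual connection and layer norm
        x = self.norm2(x + self.dropout(self.ff(x)))
        return x

A decoder layer would follow the same pattern with an additional cross-attention step that attends to encoder_result.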
import torch.optim as optim

# Use a separate name for the criterion so the per-batch loss value does not overwrite it
loss_fn = nn.CrossEntropyLoss(ignore_index=0)  # assumes token index 0 is padding
optimizer = optim.AdamW(...)

for batch in data_loader:
    x_data, y_data = batch  # assumed (source, target) pairs
    optimizer.zero_grad()
    # Teacher forcing: the decoder input is the target sequence without its last token
    output = transformer(x_data, y_data[:, :-1])
    loss = loss_fn(...)
    loss.backward()
    optimizer.step()
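The loss call is left elided above. Under the usual next-token setup implied by feeding y_data[:, :-1] to the decoder, the targets are the same sequence shifted left by one, and both logits and targets are flattened before the cross-entropy call. The shapes and the vocab dimension below are assumptions for illustration, not part of the gist.

# Hypothetical shapes: output is (batch, seq_len - 1, vocab_size),
# y_data is (batch, seq_len) of token indices.
logits = output.reshape(-1, output.size(-1))   # (batch * (seq_len - 1), vocab_size)
targets = y_data[:, 1:].reshape(-1)            # next-token targets, shifted left by one
loss = loss_fn(logits, targets)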