@tchaton
Created December 4, 2023 19:58
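
Below is a minimal, self-contained PyTorch Lightning language-modeling example: it wraps the demo Transformer from lightning.pytorch.demos in a LightningModule, splits WikiText2 into train/validation/test sets, trains with SGD and gradient clipping, and logs the NLL loss for each phase.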
import lightning as L
import torch
import torch.nn.functional as F
from lightning.pytorch.demos import Transformer, WikiText2
from torch.utils.data import DataLoader, random_split


class LanguageModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        # Vocabulary size of the WikiText2 demo dataset.
        self.model = Transformer(vocab_size=33278)

    def prepare_data(self):
        # Instantiating the dataset downloads it; Lightning calls this
        # on a single process, before setup() runs on every process.
        WikiText2()

    def setup(self, *args, **kwargs):
        dataset = WikiText2()
        n = len(dataset)
        # Hold out 2000 samples each for validation and test.
        self.train_dataset, self.val_dataset, self.test_dataset = random_split(
            dataset, [n - 4000, 2000, 2000]
        )

    def training_step(self, batch, batch_idx):
        input, target = batch
        output = self.model(input, target)
        loss = F.nll_loss(output, target.view(-1))
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        input, target = batch
        output = self.model(input, target)
        loss = F.nll_loss(output, target.view(-1))
        self.log("val_loss", loss, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        input, target = batch
        output = self.model(input, target)
        loss = F.nll_loss(output, target.view(-1))
        self.log("test_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=20, shuffle=True)

    def val_dataloader(self):
        # Evaluation loaders should not shuffle.
        return DataLoader(self.val_dataset, batch_size=20)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=20)


def main():
    L.seed_everything(42)

    # Model
    model = LanguageModel()

    # Trainer
    trainer = L.Trainer(gradient_clip_val=0.25, max_epochs=8)
    trainer.fit(model)
    trainer.test(model)


if __name__ == "__main__":
    main()
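
To scale the same script to accelerators, only the Trainer construction needs to change; the LightningModule stays untouched. A minimal sketch, assuming a Lightning 2.x Trainer and a machine with two GPUs (the device count here is illustrative, not part of the original gist):

# Hypothetical variant: same model and loops, trained on 2 GPUs with DDP.
trainer = L.Trainer(
    accelerator="gpu",
    devices=2,  # illustrative; set to the number of GPUs available
    strategy="ddp",
    gradient_clip_val=0.25,
    max_epochs=8,
)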