@wassname, last active January 12, 2024
This is my cheatsheet of my current best practices for PyTorch Lightning: `lightning_start.py`. It is verbose because it is easier to delete what is not needed than to reinvent it. I mainly log to CSV to keep things simple.
"""
This is a template for starting with pytorch lightning, it includes many extra things because it's easier to delete than reinvent.
It is written for these versions:
- lightning==2.0.2
- pytorch-optimizer==2.8.0
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import lightning.pytorch as pl
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from pathlib import Path
from IPython.display import display, HTML
import warnings
import logging
import os
plt.style.use('ggplot')
torch.set_float32_matmul_precision('medium')
warnings.filterwarnings("ignore", ".*does not have many workers.*")
# warnings.filterwarnings("ignore", ".*sampler has shuffling enabled, it is strongly recommended that.*")
logging.getLogger("pytorch_lightning").setLevel(logging.WARNING)
os.environ['TQDM_MININTERVAL'] = '5'
class SeqDataSet(torch.utils.data.Dataset):
"""fast windowed access to a dataframe."""
def __init__(self, df: pd.DataFrame, window_past=40, columns_target=['energy(kWh/hh)']):
self.df = df
self.window_past = window_past
self.columns_target = columns_target
# It's 100x faster to work with np
self._x = self.df.values
self._y = self.df[columns_target].values
self._x_cols = self.df.columns
def get_components(self, i):
        i = len(self) + i if i < 0 else i
        x = self._x[i : i + self.window_past].copy()
        y = self._y[i + self.window_past + 1].copy()
        time = self.df.index.values[i + self.window_past + 1].copy()
return x, y, time
def __len__(self):
return len(self._x) - (self.window_past + 1)
def __repr__(self):
t = self.df.index
return f'<{type(self).__name__}(shape={self.df.shape}, times={t[0]} to {t[-1]})>'
def __getitem__(self, i):
data = self.get_components(i)[:2]
return [d.astype(np.float32) for d in data]
def get_rows(self, i):
"""
Output pandas dataframes for debug/display purposes. Slower
"""
x, y, time = self.get_components(i)
t_past = self.df.index[i:i+self.window_past]
x_past = pd.DataFrame(x, columns=self._x_cols, index=t_past)
y_future = pd.DataFrame(y, columns=self.columns_target, index=[time])
return x_past, y_future
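# Usage sketch (assumes `df` is a datetime-indexed DataFrame holding the target column plus
# numeric feature columns; the names below are illustrative):
# ds = SeqDataSet(df, window_past=40, columns_target=['energy(kWh/hh)'])
# x, y = ds[0]                        # x: (window_past, n_features) float32, y: (n_targets,) float32
# x_past, y_future = ds.get_rows(0)   # slower pandas view, handy for debugging/plots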
# Model
from tsai.models.InceptionTimePlus import InceptionTimePlus17x17, InceptionTimePlus32x32, InceptionTimePlus47x47, InceptionTimePlus62x62
from tsai.models.TCN import TCN
class InceptionTimeSeq(nn.Module):
def __init__(
self,
x_dim,
y_dim,
):
super().__init__()
self.inc_block = InceptionTimePlus32x32(
c_in=x_dim,
c_out=y_dim,
ks=[3, 13, 39],
flatten=False,
            fc_dropout=0.,
            padding='causal',
)
def forward(self, x):
return self.inc_block(x.permute(0, 2, 1))
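# Shape sanity check (sketch, hypothetical dims): forward() takes x as (batch, seq_len, features)
# and permutes it to (batch, channels, seq_len) for the tsai block; verify the head output
# matches your target shape for your tsai version.
# m = InceptionTimeSeq(x_dim=8, y_dim=1)
# print(m(torch.randn(2, 40, 8)).shape)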
import lightning.pytorch as pl
from pytorch_optimizer import Ranger21, create_optimizer  # Ranger21 is used by the commented-out configure_optimizers below
from torchmetrics.functional import accuracy
from torch import optim
class PL_MODEL(pl.LightningModule):
def __init__(self, model, num_iterations, lr=3e-4, weight_decay=0, ):
super().__init__()
self._model = model
self.save_hyperparameters(ignore=['model'])
def forward(self, x):
return self._model(x)
def _shared_step(self, batch, batch_idx, phase='train'):
x, y = batch
y_pred = self.forward(x).squeeze(-1)
if phase=='pred':
return y_pred
        loss = F.smooth_l1_loss(y_pred, y)
        self.log(f"{phase}/loss", loss, on_epoch=True, on_step=True, prog_bar=True)
        self.log_dict({
            # accuracy only applies when the target is binary; drop this line for pure regression
            f"{phase}/acc": accuracy(y_pred, y, task="binary"),
        }, on_epoch=True, on_step=False)
        return loss
def training_step(self, batch, batch_idx):
return self._shared_step(batch, batch_idx, phase='train')
def validation_step(self, batch, batch_idx):
return self._shared_step(batch, batch_idx, phase='val')
def test_step(self, batch, batch_idx, dataloader_idx=0):
return self._shared_step(batch, batch_idx, phase='test')
def predict_step(self, batch, batch_idx, dataloader_idx=0):
return self._shared_step(batch, batch_idx, phase='pred').float().cpu().detach()
def configure_optimizers(self):
"""simple vanilla torch optim"""
# https://lightning.ai/docs/fabric/stable/fundamentals/precision.html
# optimizer = bnb.optim.AdamW8bit(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
optimizer = optim.AdamW(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
# https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers
scheduler = optim.lr_scheduler.OneCycleLR(
optimizer, self.hparams.lr, total_steps=self.hparams.num_iterations, verbose=True,
)
lr_scheduler = {'scheduler': scheduler, 'interval': 'step'}
return [optimizer], [lr_scheduler]
# def configure_optimizers(self):
# optimizer = Ranger21(
# self.parameters(),
# lr=self.hparams.lr,
# weight_decay=self.hparams.weight_decay,
# num_iterations=self.hparams.total_steps,
# )
# return optimizer
# Train
from torch.utils.data import DataLoader
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.callbacks import LearningRateMonitor
num_workers = 0
batch_size = 32
# Loaders
from torch.utils.data import TensorDataset
# build a TensorDataset straight from array-like X, y
to_ds = lambda X, y: TensorDataset(torch.as_tensor(np.asarray(X)).float(), torch.as_tensor(np.asarray(y)).float())
class imdbHSDataModule(pl.LightningDataModule):
def __init__(self,
df: pd.DataFrame,
batch_size: int=32,
):
super().__init__()
self.save_hyperparameters(ignore=["df"])
self.df = df
def setup(self, stage: str):
h = self.hparams
self.X = self.df['X']
self.y = self.df['y']
n = len(self.y)
self.splits = {
'train': (0, int(n * 0.5)),
'val': (int(n * 0.5), int(n * 0.75)),
'test': (int(n * 0.75), n),
}
        self.datasets = {key: to_ds(self.X[start:end], self.y[start:end]) for key, (start, end) in self.splits.items()}
def create_dataloader(self, ds, shuffle=False):
# WARNING: don't put slow logic in here as pl remakes the dl each epoch
return DataLoader(ds, batch_size=self.hparams.batch_size, drop_last=True, shuffle=shuffle)
def train_dataloader(self):
return self.create_dataloader(self.datasets['train'], shuffle=True)
def val_dataloader(self):
return self.create_dataloader(self.datasets['val'])
def test_dataloader(self):
return self.create_dataloader(self.datasets['test'])
dm = imdbHSDataModule(df, batch_size=batch_size)
dm.setup('fit')
dl_train = dm.train_dataloader()
dl_val = dm.val_dataloader()
dl_test = dm.test_dataloader()
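# Quick batch check (sketch): confirm the loaders yield (x, y) with the dtypes/shapes the model expects.
# xb, yb = next(iter(dl_train))
# print(xb.shape, xb.dtype, yb.shape, yb.dtype)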
# model
pt_model = InceptionTimeSeq(xs, ys)  # xs/ys: number of input feature columns and target columns in your data
model_name = type(pt_model).__name__
max_batches = min(len(dl_train), 1000000)
max_epochs = 180
# Wrap in lightning
model = PL_MODEL(pt_model,
weight_decay=1e-1,
lr=1e-3,
num_iterations=max_batches*max_epochs
)
timestamp = pd.Timestamp.utcnow().strftime("%Y%m%d_%H%M%S")
save_dir = f"../outputs/{timestamp}/{model_name}"
Path(save_dir).mkdir(exist_ok=True, parents=True)
trainer = pl.Trainer(
max_epochs=max_epochs,
# limit_train_batches=max_batches,
# limit_val_batches=max_batches//5,
gradient_clip_val=20,
precision="bf16-mixed",
log_every_n_steps=1,
# callbacks=[LearningRateMonitor(logging_interval='step')],
    logger=[CSVLogger(name=model_name, save_dir=save_dir, flush_logs_every_n_steps=5)],
    # default_root_dir=save_dir,
)
# train
trainer.fit(model, dl_train, dl_val)
# test
y_preds = trainer.predict(model, dataloaders=dl_test)
y_pred = torch.concat(y_preds)[:, 0].numpy()
y_pred
# Hist
def read_metrics_csv(metrics_file_path):
    """Collapse the CSVLogger's sparse metrics.csv (one row per log call) into one row per epoch."""
    df_hist = pd.read_csv(metrics_file_path)
    df_hist["epoch"] = df_hist["epoch"].ffill()
    df_histe = df_hist.set_index("epoch").groupby("epoch").last().ffill().bfill()
    return df_histe
def plot_hist(df_hist, allowlist=None, logy=False):
"""plot groups of suffixes together"""
suffixes = list(set([c.split('/')[-1] for c in df_hist.columns if '/' in c]))
for suffix in suffixes:
if allowlist and suffix not in allowlist: continue
df_hist[[c for c in df_hist.columns if c.endswith(suffix) and '/' in c]].plot(title=suffix, style='.', logy=logy)
plt.title(suffix)
plt.show()
# def read_hist(trainer: pl.Trainer):
# ts = [t for t in trainer.loggers if isinstance(t, CSVLogger)]
# try:
# metrics_file_path = Path(ts[0].experiment.metrics_file_path)
# df_histe = read_metrics_csv(metrics_file_path)
# return df_histe
# except Exception as e:
# print(e)
# df_hist = read_hist(trainer).bfill().ffill()
df_hist = read_metrics_csv(trainer.logger.experiment.metrics_file_path).bfill().ffill()
plot_hist(df_hist, ['loss', 'acc', 'auroc'])
display(df_hist)
# test
import re
from typing import Dict, List
def _transform_dl_k(k: str) -> str:
    """
    >>> _transform_dl_k('test/loss_epoch/dataloader_idx_0')
    'loss_epoch'
    """
    p = re.match(r"test\/(.+)\/dataloader_idx_\d", k)
    return p.group(1) if p else k
def rename_pl_test_results(rs: List[Dict[str, float]], ks=["train", "val", "test"], verbose=True):
"""
    PyTorch Lightning's `trainer.test` returns a list of dictionaries with the metrics logged during the test phase, with keys like `test/loss/dataloader_idx_0`. This strips the `test/.../dataloader_idx_i` wrapping and labels each dataloader with the names in `ks`.
    usage:
        rs = trainer.test(model, dataloaders=[dl_train, dl_val, dl_test, dl_ood])
        df_rs = rename_pl_test_results(rs, ["train", "val", "test", "ood"])
"""
    rs = {
        ks[i]: {_transform_dl_k(k): v for k, v in rs[i].items()}
        for i in range(len(ks))
    }
if verbose:
print(pd.DataFrame(rs).round(3).to_markdown())
return pd.DataFrame(rs)
dl_test = dm.test_dataloader()
rs = trainer.test(model, dataloaders=[dl_train, dl_val, dl_test])
df_rs = rename_pl_test_results(rs, ["train", "val", "test"])
# predict
y_test_pred = trainer.predict(model, dl_test)
y_test_pred = torch.concatenate(y_test_pred).numpy()
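# Sketch: line the predictions up against the held-out targets for a quick visual check.
# Assumes the TensorDatasets built by the DataModule above; drop_last=True means the
# predictions can be shorter than the targets, hence the slice.
# y_test_true = dm.datasets['test'].tensors[1].numpy()[: len(y_test_pred)]
# pd.DataFrame({'true': y_test_true.ravel(), 'pred': y_test_pred.ravel()}).plot(style='.')
# plt.show()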