This is my cheatsheet of current best practices for using pytorch lightning, in `lightning_start.py`. It is deliberately verbose so that I can delete what is not needed. I mainly log to CSV to keep things simple.
""" | |
This is a template for starting with pytorch lightning, it includes many extra things because it's easier to delete than reinvent. | |
It is written for these versions: | |
- lightning==2.0.2 | |
- pytorch-optimizer==2.8.0 | |
""" | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch.utils.data | |
import lightning.pytorch as pl | |
import pandas as pd | |
import numpy as np | |
from matplotlib import pyplot as plt | |
from pathlib import Path | |
from IPython.display import display, HTML | |
import warnings | |
import logging | |
import os | |
plt.style.use('ggplot') | |
torch.set_float32_matmul_precision('medium') | |
warnings.filterwarnings("ignore", ".*does not have many workers.*") | |
# warnings.filterwarnings("ignore", ".*sampler has shuffling enabled, it is strongly recommended that.*") | |
logging.getLogger("pytorch_lightning").setLevel(logging.WARNING) | |
os.environ['TQDM_MININTERVAL'] = '5' | |
class SeqDataSet(torch.utils.data.Dataset):
    """Fast windowed access to a dataframe."""

    def __init__(self, df: pd.DataFrame, window_past=40, columns_target=['energy(kWh/hh)']):
        self.df = df
        self.window_past = window_past
        self.columns_target = columns_target
        # It's ~100x faster to work with numpy arrays than with pandas
        self._x = self.df.values
        self._y = self.df[columns_target].values
        self._x_cols = self.df.columns

    def get_components(self, i):
        i = len(self) + i if i < 0 else i
        x = self._x[i : i + self.window_past].copy()
        y = self._y[i + self.window_past + 1].copy()
        time = self.df.index.values[i + self.window_past + 1].copy()
        return x, y, time

    def __len__(self):
        return len(self._x) - (self.window_past + 1)

    def __repr__(self):
        t = self.df.index
        return f'<{type(self).__name__}(shape={self.df.shape}, times={t[0]} to {t[-1]})>'

    def __getitem__(self, i):
        data = self.get_components(i)[:2]
        return [d.astype(np.float32) for d in data]

    def get_rows(self, i):
        """Output pandas dataframes for debug/display purposes. Slower."""
        x, y, time = self.get_components(i)
        t_past = self.df.index[i : i + self.window_past]
        x_past = pd.DataFrame(x, columns=self._x_cols, index=t_past)
        y_future = pd.DataFrame(y, columns=self.columns_target, index=[time])
        return x_past, y_future
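
# A minimal smoke test for SeqDataSet (a sketch on hypothetical toy data, not
# part of the original template): windowed access over a dummy half-hourly df.
# _df_demo = pd.DataFrame(
#     {'energy(kWh/hh)': np.random.rand(100), 'temp': np.random.rand(100)},
#     index=pd.date_range('2024-01-01', periods=100, freq='30min'),
# )
# _ds_demo = SeqDataSet(_df_demo, window_past=40)
# _x, _y = _ds_demo[0]  # x: (40, 2) float32 window, y: (1,) float32 target
# _x_past, _y_future = _ds_demo.get_rows(0)  # pandas versions for inspection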
# Model
from tsai.models.InceptionTimePlus import InceptionTimePlus17x17, InceptionTimePlus32x32, InceptionTimePlus47x47, InceptionTimePlus62x62
from tsai.models.TCN import TCN

class InceptionTimeSeq(nn.Module):
    def __init__(
        self,
        x_dim,
        y_dim,
    ):
        super().__init__()
        self.inc_block = InceptionTimePlus32x32(
            c_in=x_dim,
            c_out=y_dim,
            ks=[3, 13, 39],
            flatten=False,
            fc_dropout=0.,
            padding='causal',
        )

    def forward(self, x):
        # tsai blocks expect (batch, channels, seq_len)
        return self.inc_block(x.permute(0, 2, 1))
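
# Hedged shape check (a sketch, assumes tsai is installed): the model takes
# (batch, seq_len, features) and permutes to tsai's (batch, channels, seq_len).
# _m = InceptionTimeSeq(x_dim=2, y_dim=1)
# _out = _m(torch.randn(4, 40, 2))
# assert _out.shape[0] == 4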
# `lightning.pytorch` is already imported as `pl` above
from pytorch_optimizer import Ranger21  # only needed for the commented-out optimizer below
from torchmetrics.functional import accuracy
from torch import optim
class PL_MODEL(pl.LightningModule):
    def __init__(self, model, num_iterations, lr=3e-4, weight_decay=0):
        super().__init__()
        self._model = model
        self.save_hyperparameters(ignore=['model'])

    def forward(self, x):
        return self._model(x)

    def _shared_step(self, batch, batch_idx, phase='train'):
        x, y = batch
        y_pred = self.forward(x).squeeze(-1)
        if phase == 'pred':
            return y_pred
        loss = F.smooth_l1_loss(y_pred, y)
        self.log(f"{phase}/loss", loss, on_epoch=True, on_step=True, prog_bar=True)
        # note: loss is already logged above; logging the same key twice with
        # different on_step settings raises an error in lightning
        self.log_dict({
            f"{phase}/acc": accuracy(y_pred, y, task="binary"),
        }, on_epoch=True, on_step=False)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, batch_idx, phase='train')

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, batch_idx, phase='val')

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        return self._shared_step(batch, batch_idx, phase='test')

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        return self._shared_step(batch, batch_idx, phase='pred').float().cpu().detach()

    def configure_optimizers(self):
        """Simple vanilla torch optim with a OneCycle schedule."""
        # https://lightning.ai/docs/fabric/stable/fundamentals/precision.html
        # optimizer = bnb.optim.AdamW8bit(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
        optimizer = optim.AdamW(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
        # https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers
        scheduler = optim.lr_scheduler.OneCycleLR(
            optimizer, self.hparams.lr, total_steps=self.hparams.num_iterations, verbose=True,
        )
        lr_scheduler = {'scheduler': scheduler, 'interval': 'step'}
        return [optimizer], [lr_scheduler]

    # def configure_optimizers(self):
    #     optimizer = Ranger21(
    #         self.parameters(),
    #         lr=self.hparams.lr,
    #         weight_decay=self.hparams.weight_decay,
    #         num_iterations=self.hparams.num_iterations,
    #     )
    #     return optimizer
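
# Sanity-check tip (a sketch using lightning's built-in `overfit_batches` flag,
# not part of the original template): overfit a single batch first; if the loss
# doesn't drop towards zero, the model/data plumbing is broken.
# trainer_dbg = pl.Trainer(overfit_batches=1, max_epochs=50)
# trainer_dbg.fit(model, dl_train)  # `model` and `dl_train` are defined below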
# Train
from torch.utils.data import DataLoader, TensorDataset
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.callbacks import LearningRateMonitor

num_workers = 0
batch_size = 32

# Loaders
to_tensor = lambda x: torch.from_numpy(np.asarray(x)).float()  # assumption: X is array-like; this was undefined in the original
to_ds = lambda X, y: TensorDataset(to_tensor(X), torch.from_numpy(y).float())
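
# Example (hypothetical shapes, not from the original): to_ds pairs a windowed
# feature array with float targets in a TensorDataset.
# _ds = to_ds(np.zeros((8, 40, 2)), np.zeros(8))
# assert len(_ds) == 8  # _ds[0] -> (FloatTensor x, FloatTensor y)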
class imdbHSDataModule(pl.LightningDataModule):
    def __init__(self,
                 df: pd.DataFrame,
                 batch_size: int = 32,
                 ):
        super().__init__()
        self.save_hyperparameters(ignore=["df"])
        self.df = df

    def setup(self, stage: str):
        self.X = self.df['X']
        self.y = self.df['y'].to_numpy()  # to_ds expects numpy
        n = len(self.y)
        # sequential (non-shuffled) splits, since this is time-series data
        self.splits = {
            'train': (0, int(n * 0.5)),
            'val': (int(n * 0.5), int(n * 0.75)),
            'test': (int(n * 0.75), n),
        }
        self.datasets = {key: to_ds(self.X[start:end], self.y[start:end]) for key, (start, end) in self.splits.items()}

    def create_dataloader(self, ds, shuffle=False):
        # WARNING: don't put slow logic in here, as pl remakes the dataloader each epoch
        return DataLoader(ds, batch_size=self.hparams.batch_size, drop_last=True, shuffle=shuffle)

    def train_dataloader(self):
        return self.create_dataloader(self.datasets['train'], shuffle=True)

    def val_dataloader(self):
        return self.create_dataloader(self.datasets['val'])

    def test_dataloader(self):
        return self.create_dataloader(self.datasets['test'])
dm = imdbHSDataModule(df, batch_size=batch_size)
dm.setup('fit')
dl_train = dm.train_dataloader()
dl_val = dm.val_dataloader()
dl_test = dm.test_dataloader()

# model
x_dim, y_dim = 2, 1  # placeholders: set these to your input-feature and target dims
pt_model = InceptionTimeSeq(x_dim, y_dim)
model_name = type(pt_model).__name__
max_batches = min(len(dl_train), 1000000)
max_epochs = 180

# Wrap in lightning
model = PL_MODEL(pt_model,
                 weight_decay=1e-1,
                 lr=1e-3,
                 num_iterations=max_batches * max_epochs,
                 )
timestamp = pd.Timestamp.now().strftime('%Y%m%d-%H%M%S')  # was undefined in the original
save_dir = f"../outputs/{timestamp}/{model_name}"
Path(save_dir).mkdir(exist_ok=True, parents=True)
trainer = pl.Trainer(
    max_epochs=max_epochs,
    # limit_train_batches=max_batches,
    # limit_val_batches=max_batches // 5,
    gradient_clip_val=20,
    precision="bf16-mixed",
    log_every_n_steps=1,
    # callbacks=[LearningRateMonitor(logging_interval='step')],
    # the CSV logger is required below, where we read back metrics_file_path
    logger=[CSVLogger(name=model_name, save_dir=save_dir, flush_logs_every_n_steps=5)],
    # default_root_dir=save_dir,
)
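
# Optional quality-of-life callbacks (a sketch; EarlyStopping and ModelCheckpoint
# are standard lightning.pytorch.callbacks, and the monitored key must match a
# metric logged above, e.g. "val/loss"):
# from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
# trainer = pl.Trainer(
#     max_epochs=max_epochs,
#     logger=[CSVLogger(name=model_name, save_dir=save_dir)],
#     callbacks=[
#         EarlyStopping(monitor="val/loss", patience=5),
#         ModelCheckpoint(monitor="val/loss", save_top_k=1),
#     ],
# )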
# train
trainer.fit(model, dl_train, dl_val)

# predict
y_preds = trainer.predict(model, dataloaders=dl_test)
y_pred = torch.concat(y_preds)[:, 0].numpy()
y_pred
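
# Quick visual check (a sketch, assuming the test targets are reachable via the
# TensorDataset held by the datamodule; drop_last=True may shorten y_pred):
# _y_true = dm.datasets['test'].tensors[1].numpy()[:len(y_pred)]
# plt.plot(_y_true, label='true')
# plt.plot(y_pred, label='pred')
# plt.legend(); plt.show()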
# History
def read_metrics_csv(metrics_file_path):
    df_hist = pd.read_csv(metrics_file_path)
    df_hist["epoch"] = df_hist["epoch"].ffill()
    df_histe = df_hist.set_index("epoch").groupby("epoch").last().ffill().bfill()
    return df_histe

def plot_hist(df_hist, allowlist=None, logy=False):
    """Plot groups of metrics that share a suffix (e.g. train/loss and val/loss) together."""
    suffixes = list(set(c.split('/')[-1] for c in df_hist.columns if '/' in c))
    for suffix in suffixes:
        if allowlist and suffix not in allowlist:
            continue
        df_hist[[c for c in df_hist.columns if c.endswith(suffix) and '/' in c]].plot(title=suffix, style='.', logy=logy)
        plt.title(suffix)
        plt.show()

# def read_hist(trainer: pl.Trainer):
#     ts = [t for t in trainer.loggers if isinstance(t, CSVLogger)]
#     try:
#         metrics_file_path = Path(ts[0].experiment.metrics_file_path)
#         df_histe = read_metrics_csv(metrics_file_path)
#         return df_histe
#     except Exception as e:
#         print(e)
# df_hist = read_hist(trainer).bfill().ffill()

df_hist = read_metrics_csv(trainer.logger.experiment.metrics_file_path).bfill().ffill()
plot_hist(df_hist, ['loss', 'acc', 'auroc'])
display(df_hist)
# test
def _transform_dl_k(k: str) -> str:
    """
    >>> _transform_dl_k('test/loss_epoch/dataloader_idx_0')
    'loss_epoch'
    """
    p = re.match(r"test\/(.+)\/dataloader_idx_\d", k)
    return p.group(1) if p else k

def rename_pl_test_results(rs: List[Dict[str, float]], ks=["train", "val", "test"], verbose=True):
    """
    pytorch lightning's `trainer.test` returns a list of dictionaries with the metrics
    logged during the test phase, with keys like `test/loss_epoch/dataloader_idx_0`.
    This strips the prefixes and labels each dataloader `train`, `val`, etc.

    usage:
        rs = trainer.test(model, dataloaders=[dl_train, dl_val, dl_test, dl_ood])
        df_rs = rename_pl_test_results(rs, ["train", "val", "test", "ood"])
    """
    rs = {
        ks[i]: {_transform_dl_k(k): v for k, v in rs[i].items()}
        for i in range(len(ks))
    }
    if verbose:
        print(pd.DataFrame(rs).round(3).to_markdown())
    return pd.DataFrame(rs)
dl_test = dm.test_dataloader()
rs = trainer.test(model, dataloaders=[dl_train, dl_val, dl_test])
df_rs = rename_pl_test_results(rs, ["train", "val", "test"])
# predict
y_test_pred = trainer.predict(model, dl_test)
y_test_pred = torch.concat(y_test_pred).numpy()