# pytorch lightning minimal training loop
from typing import Callable
from loguru import logger
import torch
from torch import nn
def convert_layers(model: nn.Module, original: nn.Module, value: bool):
    """Recursively set the training mode of every layer of type `original` to `value`."""
    for child_name, child in model.named_children():
        if isinstance(child, original):
            logger.debug(f"{child} switched from {child.training} to {value}")
            child.train(value)
        else:
            convert_layers(child, original, value)


def to_mcdropout(model):
    """Put the model in eval mode but keep its dropout layers sampling (Monte Carlo dropout)."""
    logger.debug('enabling mcdropout')
    model.eval()
    convert_layers(model, torch.nn.modules.dropout.Dropout2d, True)
    convert_layers(model, torch.nn.modules.dropout.Dropout, True)
    convert_layers(model, torch.nn.modules.dropout.Dropout3d, True)
    convert_layers(model, torch.nn.modules.dropout.Dropout1d, True)
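

# Hedged usage sketch (not part of the original gist): with dropout forced on,
# predictive uncertainty can be estimated by running several stochastic forward
# passes over the same input and taking the mean / std of the samples. `model`
# and `x` stand in for whatever network and input batch you are using.
def mc_dropout_predict(model: nn.Module, x: torch.Tensor, n_samples: int = 30):
    to_mcdropout(model)
    with torch.no_grad():
        samples = torch.stack([model(x) for _ in range(n_samples)])  # [n_samples, *output_shape]
    return samples.mean(0), samples.std(0)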
from convcnp.data.datamodule import PriceDataModule
from torch import nn
import torch
import torch.nn.functional as F
import torchmetrics
from ranger21 import Ranger21
from pathlib import Path
import pandas as pd
from loguru import logger
# # Lightning
import pytorch_lightning as pl
from .utils import cat_tensordicts
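
# %%
# `gaussian_logpdf` is used in `_step` below but never imported in this snippet;
# in the original it presumably comes from the convcnp codebase. A minimal,
# assumed-compatible sketch: the 'batched_mean' reduction sums the log-density
# over all non-batch dims and then averages over the batch.
def gaussian_logpdf(inputs, mean, sigma, reduction=None):
    logp = torch.distributions.Normal(loc=mean, scale=sigma).log_prob(inputs)
    if reduction is None:
        return logp
    if reduction == 'sum':
        return logp.sum()
    if reduction == 'mean':
        return logp.mean()
    if reduction == 'batched_mean':
        return logp.reshape(logp.shape[0], -1).sum(-1).mean()
    raise ValueError(f'unknown reduction: {reduction!r}')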
# %%
class PLTimeGrad(pl.LightningModule):
    def __init__(
        self,
        num_batches_per_epoch,
        in_channels=1,
        out_channels=2,
        epochs=30,
        lr=1e-3,
        weight_decay=1e-6,
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters()
        h = self.hparams
        # `Net` is the underlying forecasting network; it is not defined in this gist
        # (a stand-in sketch is given below so the module can be smoke-tested).
        self.net = Net()
        self.mse = torchmetrics.MeanSquaredError()

    def forward(
        self,
        past_time_feat,
        past_target,
        future_time_feat,
    ):
        r = self.net(
            t_context=past_time_feat,
            y_context=past_target,
            t_target=future_time_feat,
        )
        return r

    def _step(self, b, batch_idx, step="train"):
        # b contains: past_dts, past_time_feat, past_input_values, past_target,
        # future_dts, future_time_feat, future_target
        y_mean, y_std = self(b['past_time_feat'], b['past_target'],
                             b['future_time_feat'])
        # negative log-likelihood of the true future under the predicted Gaussian
        loss = -gaussian_logpdf(b['future_target'], y_mean, y_std,
                                'batched_mean')
        self.mse(y_mean, b['future_target'])
        self.log_dict({
            f"loss/{step}": loss.detach().item(),
            f"mse/{step}": self.mse,
        })
        assert torch.isfinite(loss)
        return loss

    def training_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, step="train")

    def validation_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, step="val")

    def predict_step(self, b, batch_idx, **kwargs):
        """Return predictions and targets for a batch."""
        y_mean, y_std = self(b['past_time_feat'], b['past_target'],
                             b['future_time_feat'])
        return {
            "y_mean": y_mean,
            "y_std": y_std,
            "past_dts": b['past_dts'],
            "future_target": b["future_target"],
            "past_target": b["past_target"],
            "future_dts": b['future_dts'],
        }

    def configure_optimizers(self):
        # Ranger21 needs the schedule length (epochs and batches per epoch) up front
        return Ranger21(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
            num_epochs=self.hparams.epochs,
            num_batches_per_epoch=self.hparams.num_batches_per_epoch,
        )
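
# %%
# `Net` is not defined in this gist; the real model is presumably a ConvCNP-style
# network defined elsewhere. Below is a hypothetical stand-in (not the author's
# architecture) that predicts a Gaussian mean and scale from the target time
# features, so the LightningModule above can be smoke-tested. It assumes inputs
# shaped [batch, seq_len, features].
class Net(nn.Module):
    def __init__(self, hidden=32):
        super().__init__()
        # LazyLinear infers the input feature size on the first forward pass
        self.mlp = nn.Sequential(
            nn.LazyLinear(hidden), nn.ReLU(), nn.Dropout(0.1), nn.Linear(hidden, 2)
        )

    def forward(self, t_context, y_context, t_target):
        out = self.mlp(t_target)                 # [batch, seq_len, 2]
        y_mean = out[..., :1]                    # predicted mean
        y_std = F.softplus(out[..., 1:]) + 1e-3  # predicted scale, kept positive
        return y_mean, y_std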
# %%
# datamodule
dm = PriceDataModule(batch_size=128, limit_train_iters=19000)
dm.setup()
ds = dm.ds_test.datasets[0]
c_in = ds.data_x.shape[1]
# %%
pl_net = PLTimeGrad(
    # data
    c_in=c_in,
    c_out=1,
    num_batches_per_epoch=dm.num_batches_per_epoch,
    decoder_seq_len=dm.hparams.pred_len,
    encoder_seq_len=dm.hparams.seq_len,
    # train
    epochs=50,
    lr=2e-3,
    weight_decay=1e-6,
)
pl_net
# %%
# QC
dl = dm.val_dataloader()
b = next(iter(dl))
device = "cuda" if torch.cuda.is_available() else "cpu"  # `device` was not defined in the original snippet
b = {k: bb.to(device) for k, bb in b.items()}
pl_net.to(device)
with torch.no_grad():
    r = pl_net.training_step(b, 0).cpu()
r
# %%
from torchinfo import summary
ms = summary(
    pl_net.net,
    input_data=[
        b["past_time_feat"],
        b["past_target"],
        b["future_time_feat"],
    ],
    depth=3,
    col_names=[
        "input_size",
        "output_size",
        "num_params",
        # "params_percent",
        # "kernel_size",
        # "mult_adds",
        # "trainable"
    ],
    col_width=15,
    device='cuda',
)
ms
# %%
# ## Trainer
# make the save dir
timestamp = pd.Timestamp.utcnow().strftime('%Y%m%d_%H-%M-%S')
save_dir = Path(f"../outputs/{timestamp}")
# trainer
trainer = pl.Trainer(
    accelerator="gpu",
    # Training length
    max_epochs=pl_net.hparams.epochs,
    limit_train_batches=dm.num_batches_per_epoch,
    limit_val_batches=dm.num_batches_per_epoch // 5,
    # callbacks and loggers
    default_root_dir=save_dir,
)
# train
trainer.fit(pl_net, train_dataloaders=dm.train_dataloader(), val_dataloaders=dm.val_dataloader())
# %%
# more complicated one
# # ## Trainer
# import ipynbname
# timestamp = pd.Timestamp.utcnow().strftime('%Y%m%d_%H-%M-%S')
# nb_fname = ipynbname.name()
# save_dir = Path(f"../outputs/{timestamp}_{nb_fname}")
# log_every_n_steps = 1
# assert (dm.num_batches_per_epoch // 5) > log_every_n_steps, \
#     f'too few val batches: {dm.num_batches_per_epoch // 5} must exceed log_every_n_steps={log_every_n_steps}'
# callbacks = [
#     ModelCheckpoint(
#         monitor="loss/val", every_n_epochs=1, save_last=True
#     ),
# ]
# trainer = pl.Trainer(
#     # Training length
#     max_epochs=pl_net.hparams.epochs,
#     limit_train_batches=dm.num_batches_per_epoch,
#     limit_val_batches=dm.num_batches_per_epoch // 5,
#     # GPU
#     # gpus=1,
#     accelerator="gpu",
#     # Logging
#     default_root_dir=save_dir,
#     logger=loggers,
#     log_every_n_steps=log_every_n_steps,
#     # Callbacks
#     callbacks=callbacks,
# )
# %%
from tbparse import SummaryReader
reader = SummaryReader(trainer.logger.log_dir, pivot=True)
df_hist = reader.scalars.drop(columns=['hp_metric', 'epoch']).set_index('step')
df_hist.ffill().plot()
# %%
preds = trainer.predict(model=pl_net, dataloaders=dm.val_dataloader())
preds = cat_tensordicts(preds)
preds.keys()
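# %%
# Hedged sketch (not part of the original gist): inspect one prediction window by
# plotting the predicted mean with a ±2 std band against the true future values.
# Assumes matplotlib is installed and that y_mean / y_std / future_target squeeze
# down to 1-D per sample.
import matplotlib.pyplot as plt

i = 0  # which validation window to inspect
mean = preds["y_mean"][i].squeeze()
std = preds["y_std"][i].squeeze()
true = preds["future_target"][i].squeeze()
plt.plot(true, label="future_target")
plt.plot(mean, label="y_mean")
plt.fill_between(range(len(mean)), mean - 2 * std, mean + 2 * std, alpha=0.3, label="±2 std")
plt.legend()
plt.show()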
# utils.py (imported above as `.utils`)
import pytorch_lightning as pl
from pathlib import Path
import pandas as pd
import torch


def cat_tensordicts(r):
    """Concatenate a list of dicts of tensors along dim 0 and move the results to cpu."""
    ks = r[0].keys()
    return {k: torch.concatenate([rr[k] for rr in r], 0).detach().cpu() for k in ks}
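

# Hedged usage sketch: `cat_tensordicts` turns the list of per-batch dicts that
# trainer.predict returns into a single dict of concatenated cpu tensors, e.g.
#   batches = [{"y": torch.randn(4, 3)}, {"y": torch.randn(4, 3)}]
#   cat_tensordicts(batches)["y"].shape  # torch.Size([8, 3])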