"""Created June 30, 2024.

The quantile monitor tracks the inputs and outputs of selected modules, as well
as simple transforms of them, and logs the requested quantile values.

Related paper (paper-reading notes):
- Small-scale proxies for large-scale Transformer training instabilities
- Mitchell Wortsman et al.
- Notion link: # noqa
"""
from __future__ import annotations
import logging
import typing as t
import lightning.pytorch as pl
import numpy as np
import torch
from lightning.pytorch import callbacks
from lightning.pytorch.utilities.rank_zero import rank_zero_only
from lightning.pytorch.utilities.types import STEP_OUTPUT
from torch.utils.hooks import RemovableHandle
# Module-level logger named after the dotted module path so it participates in
# the standard logger hierarchy (a `__file__`-based name is a filesystem path
# and cannot be configured through parent package loggers).
logger = logging.getLogger(__name__)
def sampling_elements_from_tensor(
    x: torch.Tensor, n: int, rng_seed: t.Optional[int] = None
) -> torch.Tensor:
    """Sample ``n`` unique elements from a tensor using a NumPy random generator.

    The tensor is flattened first. If it holds at most ``n`` elements, the whole
    flattened tensor is returned unchanged (no sampling, original order).

    Args:
        x (torch.Tensor): The input tensor from which elements are sampled.
        n (int): The number of elements to sample. Must be non-negative.
        rng_seed (int, optional): Seed for the random number generator. When
            given, the sampled positions are deterministic. Defaults to None,
            which results in non-deterministic sampling.

    Returns:
        torch.Tensor: A 1-D tensor with ``min(n, x.numel())`` elements of ``x``.

    Raises:
        ValueError: If ``n`` is negative.
    """
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}")
    x_flat = x.flatten()
    numel = x_flat.numel()
    if numel <= n:
        return x_flat
    # MT19937 is used deliberately: switching bit generators (e.g. to the
    # PCG64 behind np.random.default_rng) would change which positions a given
    # seed selects.
    rng = np.random.Generator(np.random.MT19937(seed=rng_seed))
    indices = rng.choice(numel, n, replace=False)
    # Build the index tensor explicitly on x's device rather than indexing a
    # torch tensor with a raw NumPy array.
    index = torch.as_tensor(indices, dtype=torch.long, device=x_flat.device)
    return x_flat[index]
class QuantileMonitorCallback(callbacks.Callback):
    """Callback to log quantile values of torch module input/output during training.

    On logging steps, forward hooks are attached to the LM-head module (its
    output is monitored) and to every attention module (its first positional
    input is monitored). The hooks store quantile values in ``self.metrics``;
    at the end of the batch the metrics are sent to the trainer's loggers and
    the hooks are removed again.
    """

    # Upper bound on the number of elements fed to ``torch.quantile``; larger
    # activations are subsampled. NOTE(review): presumably chosen to stay under
    # torch.quantile's input-size limit (2**24 elements) -- confirm before
    # raising it.
    max_elem_number = 16_000_000

    def __init__(
        self,
        log_every_n_steps: int = 1,
        lmhead_module_name: str = "model.linear",
        attention_module_suffix: str = "self_attn",
    ) -> None:
        """Initialize the QuantileMonitorCallback.

        Args:
            log_every_n_steps: Logging frequency, in optimizer steps
                (``trainer.global_step``). When a trainer runs ``max_steps``,
                ``global_step`` goes up to ``max_steps`` while the batch index
                advances ``accumulate_grad_batches`` times per global step.
                Overall, the monitor logs ``max_steps / log_every_n_steps``
                times.
            lmhead_module_name: Exact name (as yielded by ``named_modules()``)
                of the module whose output is monitored.
            attention_module_suffix: Every module whose name ends with this
                suffix has its first positional input monitored.

        Raises:
            ValueError: If ``log_every_n_steps`` is smaller than 1.
        """
        if log_every_n_steps < 1:
            raise ValueError(
                f"log_every_n_steps must be >= 1, got {log_every_n_steps}"
            )
        super().__init__()
        self.log_every_n_steps = log_every_n_steps
        self.lmhead_module_name = lmhead_module_name
        self.attention_module_suffix = attention_module_suffix
        self.log_step_flag = False
        self.qt_vals = torch.asarray(
            [0.0, 0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99, 1.0],
            dtype=torch.float,
        )
        self.qt_names = [f"QT_{val:.2f}" for val in self.qt_vals.tolist()]
        self._handles: t.List[RemovableHandle] = []
        self.metrics: t.Dict[str, t.Any] = {}

    def _record_quantiles(self, name: str, y: t.Any) -> None:
        """Store quantiles of ``y`` in ``self.metrics`` as ``quantile/{name}.QT_*``."""
        if not isinstance(y, torch.Tensor) or y.numel() == 0:
            logger.debug("Skipping quantiles for %s: no non-empty tensor", name)
            return
        # Detach so collecting statistics never extends the autograd graph. The
        # fixed seed (0) makes the sampled positions deterministic.
        y = sampling_elements_from_tensor(y.detach(), self.max_elem_number, 0)
        # torch.quantile needs a floating input and `q` with the same device
        # and dtype as the input.
        y = y.to(dtype=torch.float)
        q = self.qt_vals.to(device=y.device, dtype=y.dtype)
        qt_values = torch.quantile(y, q).tolist()
        for qt_name, qt_value in zip(self.qt_names, qt_values):
            self.metrics[f"quantile/{name}.{qt_name}"] = qt_value

    def _get_lmhead_hook_fn(self, name: str) -> t.Callable:
        """Return a forward hook that records quantiles of the module output."""

        def _hook(
            mod: torch.nn.Module,
            inp: t.Tuple[t.Any, ...],
            oup: torch.Tensor,
        ) -> None:
            self._record_quantiles(name, oup)

        return _hook

    def _get_mha_hook_fn(self, name: str) -> t.Callable:
        """Return a forward hook that records quantiles of the first positional input."""

        def _hook(
            mod: torch.nn.Module,
            inp: t.Tuple[t.Any, ...],
            oup: t.Any,
        ) -> None:
            # Forward hooks only receive positional arguments; a module called
            # purely with keyword arguments yields an empty tuple here.
            if not inp:
                return
            self._record_quantiles(name, inp[0])

        return _hook

    def _remove_handles(self) -> None:
        """Remove every registered forward hook and forget its handle."""
        for handle in self._handles:
            handle.remove()
        self._handles = []

    @rank_zero_only
    def _log_metrics(self, trainer: "pl.Trainer") -> None:
        """Send ``self.metrics`` to every logger attached to the trainer (rank 0 only)."""
        if not self.metrics:
            return
        for pl_logger in trainer.loggers:
            # `step` is intentionally not passed. Original note: log_metrics
            # adds global_step automatically, and passing `step` caused
            # confusion with the actual global_step.
            # NOTE(review): step handling is logger-specific -- confirm for
            # the logger in use.
            pl_logger.log_metrics(self.metrics)

    def on_train_batch_start(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        batch: t.Any,
        batch_idx: int,
    ) -> None:
        """Set up the module quantile hooks if this batch is a logging step."""
        # Log only on a batch that is not accumulating gradients (per Lightning's
        # private `_should_accumulate`) and only every `log_every_n_steps` steps.
        self.log_step_flag = (
            not trainer.fit_loop._should_accumulate()
            and trainer.global_step % self.log_every_n_steps == 0
        )
        if not self.log_step_flag:
            return
        # Drop handles left over from an earlier step (e.g. a batch that never
        # reached `on_train_batch_end`) so hooks are never registered twice.
        self._remove_handles()
        self.metrics = {}
        for name, module in pl_module.named_modules():
            if name == self.lmhead_module_name:
                hook = self._get_lmhead_hook_fn(name)
            elif name.endswith(self.attention_module_suffix):
                hook = self._get_mha_hook_fn(name)
            else:
                continue
            self._handles.append(module.register_forward_hook(hook))

    def on_train_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: STEP_OUTPUT,
        batch: t.Any,
        batch_idx: int,
    ) -> None:
        """Log the collected quantiles and tear down the hooks."""
        if not self.log_step_flag:
            return
        self._log_metrics(trainer)
        self._remove_handles()
