Skip to content

Instantly share code, notes, and snippets.

@rgtjf
Created June 30, 2024 14:02
Show Gist options
  • Save rgtjf/15165eecb0c1e0b8c364ab2788957b51 to your computer and use it in GitHub Desktop.
Save rgtjf/15165eecb0c1e0b8c364ab2788957b51 to your computer and use it in GitHub Desktop.
The quantile monitor monitors the input and output, as well as simple transforms to them. It logs the quantile values needed. Link to paper reading paper: - Small-scale proxies for large-scale Transformer training instabilities - Mitchell Wortsman et al. - https://arxiv.org/abs/2309.14322 - notion link: https://www.notion.so/nyonic/Small-scale-p…
"""
The quantile monitor monitors the input and output, as well as simple transforms to them.
It logs the quantile values needed.
Reference paper and paper-reading notes:
- Small-scale proxies for large-scale Transformer training instabilities
- Mitchell Wortsman et al.
- https://arxiv.org/abs/2309.14322
- notion link: https://www.notion.so/nyonic/Small-scale-proxies-for-large-scale-Transformer-training-instabilities-95f7d37711f34d8ebae4f505bc160830 # noqa
"""
from __future__ import annotations
import logging
import typing as t
import lightning.pytorch as pl
import numpy as np
import torch
from lightning.pytorch import callbacks
from lightning.pytorch.utilities.rank_zero import rank_zero_only
from lightning.pytorch.utilities.types import STEP_OUTPUT
from torch.utils.hooks import RemovableHandle
# Name the logger after the module's import path (not its file path) so it
# takes part in the standard dotted logger hierarchy and can be configured
# through its parent package.
logger = logging.getLogger(__name__)
def sampling_elements_from_tensor(
    x: torch.Tensor, n: int, rng_seed: t.Optional[int] = None
) -> torch.Tensor:
    """Sample up to ``n`` elements from a tensor, without replacement.

    The tensor is flattened and ``n`` distinct *positions* are drawn with a
    numpy ``MT19937`` generator. Values may still repeat if ``x`` itself
    contains duplicates. If a ``rng_seed`` is provided, the sampling is
    deterministic.

    Args:
        x (torch.Tensor): The input tensor from which elements are sampled.
        n (int): The maximum number of elements to return. Must be >= 0.
        rng_seed (int, optional): The seed for the random number generator.
            Default is None, which results in non-deterministic behavior.

    Returns:
        torch.Tensor: A 1-D tensor. It holds exactly ``n`` sampled elements
        when ``x`` has more than ``n`` elements; otherwise it is the whole
        flattened input, unsampled.

    Raises:
        ValueError: If ``n`` is negative.
    """
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}")

    x_flat = x.flatten()
    total = x_flat.numel()
    if total <= n:
        # Nothing to drop: hand back every element.
        return x_flat

    rng = np.random.Generator(np.random.MT19937(seed=rng_seed))
    picked = rng.choice(total, n, replace=False)
    # Index with a torch tensor living on the same device as the data instead
    # of relying on implicit numpy-array indexing.
    index = torch.from_numpy(picked).to(device=x_flat.device, dtype=torch.long)
    return x_flat[index]
class QuantileMonitorCallback(callbacks.Callback):
    """Callback to log quantile values of torch module input/output during training.

    On every logging step, forward hooks are attached to selected submodules
    of the LightningModule for the duration of that batch:

    - the module named exactly ``"model.linear"``: quantiles of its output;
    - every module whose name ends with ``"self_attn"``: quantiles of its
      first positional input.

    The collected values are sent to ``trainer.logger`` at the end of the
    batch under keys of the form ``quantile/<module name>.QT_<q>``, and the
    hooks are removed again. Both batch callbacks run on rank zero only.
    """

    # Tensors with more elements than this are subsampled before being passed
    # to `torch.quantile`; see the linked PyTorch issue.
    # TODO: https://github.com/pytorch/pytorch/issues/64947
    max_elem_number = 16_000_000

    def __init__(
        self,
        log_every_n_steps: int = 1,
    ) -> None:
        """Initialize the QuantileMonitorCallback.

        Args:
            log_every_n_steps: int. Determining the logging frequency. When we set a
                trainer to run `max_steps`, then the `global_steps` would be it. And the
                biggest `batch_idx` equals to `accumulate_grad_batches` x
                `trainer.global_steps`. Overall, the monitor would log `max_steps` /
                `log_every_n_steps` times.

        Raises:
            ValueError: If `log_every_n_steps` is smaller than 1.
        """
        super().__init__()
        if log_every_n_steps < 1:
            # Fail at construction time instead of with a ZeroDivisionError
            # in the middle of training.
            raise ValueError(
                f"log_every_n_steps must be a positive integer, got {log_every_n_steps}"
            )
        self.log_every_n_steps = log_every_n_steps
        # True only while the current batch is one whose quantiles get logged.
        self.log_step_flag = False
        self.qt_vals = torch.asarray(
            [0.0, 0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99, 1],
            dtype=torch.float,
        )
        self.qt_names = [f"QT_{val:.2f}" for val in self.qt_vals.tolist()]
        self._handles: t.List[RemovableHandle] = []
        self.metrics: t.Dict[str, t.Any] = {}

    def _record_quantiles(self, name: str, values: torch.Tensor) -> None:
        """Compute the quantiles of ``values`` and store them in ``self.metrics``."""
        # Detach first: this is pure monitoring and must not extend the
        # autograd graph of the training step.
        sample = sampling_elements_from_tensor(
            values.detach(), self.max_elem_number, 0
        )
        sample = sample.to(device=self.qt_vals.device, dtype=torch.float)
        quantiles = torch.quantile(sample, self.qt_vals).tolist()
        for qt_name, quantile in zip(self.qt_names, quantiles):
            self.metrics[f"quantile/{name}.{qt_name}"] = quantile

    def _remove_handles(self) -> None:
        """Detach every registered forward hook and forget its handle."""
        for handle in self._handles:
            handle.remove()
        # Drop the stale handles so the list does not grow on every logging step.
        self._handles.clear()

    def _get_lmhead_hook_fn(self, name: str) -> t.Callable:
        """Build a forward hook that records quantiles of the module output."""

        def _hook(
            mod: torch.nn.Module,
            inp: t.Tuple[t.Any, ...],
            oup: torch.Tensor,
        ) -> None:
            self._record_quantiles(name, oup)

        return _hook

    def _get_mha_hook_fn(self, name: str) -> t.Callable:
        """Build a forward hook that records quantiles of the first module input."""

        def _hook(
            mod: torch.nn.Module,
            inp: t.Tuple[t.Any, ...],
            oup: t.Any,
        ) -> None:
            # Assumes the module is called with its tensor of interest as the
            # first positional argument -- TODO confirm for keyword-only calls.
            self._record_quantiles(name, inp[0])

        return _hook

    @rank_zero_only
    def on_train_batch_start(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        batch: t.Any,
        batch_idx: int,
    ) -> None:
        """Set up the module quantile handles.

        A batch is a logging batch when the fit loop is not accumulating
        gradients for it and ``trainer.global_step`` is a multiple of
        ``log_every_n_steps``.
        """
        self.log_step_flag = (
            not trainer.fit_loop._should_accumulate()
            and trainer.global_step % self.log_every_n_steps == 0
        )
        if not self.log_step_flag:
            return

        # Start from a clean slate: no leftover hooks, no leftover metrics.
        self._remove_handles()
        self.metrics = {}
        for name, module in pl_module.named_modules():
            if name == "model.linear":
                hook_fn = self._get_lmhead_hook_fn(name)
            elif name.endswith("self_attn"):
                hook_fn = self._get_mha_hook_fn(name)
            else:
                continue
            self._handles.append(module.register_forward_hook(hook_fn))

    @rank_zero_only
    def on_train_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: STEP_OUTPUT,
        batch: t.Any,
        batch_idx: int,
    ) -> None:
        """Log the collected quantiles and tear down the handles."""
        if self.log_step_flag and trainer.logger is not None:
            # Original author's note: `log_metrics` would add global_step
            # automatically, and passing `step` explicitly would cause
            # confusion with the actual global_step -- so no step is passed.
            trainer.logger.log_metrics(self.metrics)
        self._remove_handles()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment