Created
June 30, 2024 14:02
-
-
Save rgtjf/15165eecb0c1e0b8c364ab2788957b51 to your computer and use it in GitHub Desktop.
The quantile monitor tracks module inputs and outputs (and simple transforms of them) and logs the requested quantile values. Paper and reading notes: - Small-scale proxies for large-scale Transformer training instabilities - Mitchell Wortsman et al. - https://arxiv.org/abs/2309.14322 - Notion notes: https://www.notion.so/nyonic/Small-scale-proxies-for-large-scale-Transformer-training-instabilities-95f7d37711f34d8ebae4f505bc160830
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
The quantile monitor tracks module inputs and outputs (and simple transforms of them)
and logs the requested quantile values.
Paper and reading notes:
- Small-scale proxies for large-scale Transformer training instabilities | |
- Mitchell Wortsman et al. | |
- https://arxiv.org/abs/2309.14322 | |
- notion link: https://www.notion.so/nyonic/Small-scale-proxies-for-large-scale-Transformer-training-instabilities-95f7d37711f34d8ebae4f505bc160830 # noqa | |
""" | |
from __future__ import annotations | |
import logging | |
import typing as t | |
import lightning.pytorch as pl | |
import numpy as np | |
import torch | |
from lightning.pytorch import callbacks | |
from lightning.pytorch.utilities.rank_zero import rank_zero_only | |
from lightning.pytorch.utilities.types import STEP_OUTPUT | |
from torch.utils.hooks import RemovableHandle | |
# Module-level logger named after the import path (not the file path), so it
# fits into the standard hierarchical logging configuration.
logger = logging.getLogger(__name__)
def sampling_elements_from_tensor(
    x: torch.Tensor, n: int, rng_seed: t.Optional[int] = None
) -> torch.Tensor:
    """
    Sample ``n`` elements from ``x`` at distinct positions.

    The tensor is flattened first. If it has more than ``n`` elements, ``n``
    distinct flat positions are drawn without replacement using a NumPy
    ``Generator`` backed by ``MT19937``. Otherwise the whole flattened tensor is
    returned. Positions are unique, but the sampled *values* may repeat if
    ``x`` itself contains duplicates.

    Args:
        x (torch.Tensor): The input tensor to sample from (any shape).
        n (int): The maximum number of elements to return.
        rng_seed (int, optional): Seed for the random number generator. With a
            fixed seed the same positions are chosen on every call for a given
            ``x.numel()``. Default is None, which gives non-deterministic
            sampling.

    Returns:
        torch.Tensor: A 1-D tensor with ``min(n, x.numel())`` elements (for
        ``n >= 0``), with the same dtype and device as ``x``.
    """
    x_flat = x.flatten()
    numel = torch.numel(x_flat)
    if numel <= n:
        # Nothing to subsample; hand back every element.
        return x_flat
    rng = np.random.Generator(np.random.MT19937(seed=rng_seed))
    indices = rng.choice(numel, n, replace=False)
    # Build the index tensor explicitly on x's device instead of relying on an
    # implicit NumPy -> tensor conversion inside the indexing call.
    index = torch.from_numpy(indices).to(device=x_flat.device)
    return x_flat[index]
class QuantileMonitorCallback(callbacks.Callback):
    """Callback to log quantile values of torch module input/output during training.

    On logging steps it registers forward hooks on:

    * the module named ``"model.linear"``, whose *output* is monitored, and
    * every module whose name ends with ``"self_attn"``, whose *first positional
      input* is monitored.

    Each hook stores the quantiles listed in ``qt_vals`` into ``self.metrics``.
    At the end of the batch the metrics are passed to ``trainer.logger`` and
    the hooks are removed.
    """

    # Monitored tensors are subsampled to at most this many elements before
    # `torch.quantile` is called; presumably this works around the input-size
    # limit discussed in the issue below.
    # TODO: https://github.com/pytorch/pytorch/issues/64947
    max_elem_number = 16_000_000

    def __init__(
        self,
        log_every_n_steps: int = 1,
    ) -> None:
        """Initialize the QuantileMonitorCallback.

        Args:
            log_every_n_steps: int. Logging frequency in optimizer steps
                (``trainer.global_step``). Quantiles are collected when the fit
                loop is not accumulating gradients and ``trainer.global_step``
                is a multiple of this value. A run of ``max_steps`` optimizer
                steps therefore logs about ``max_steps / log_every_n_steps``
                times. Note that ``batch_idx`` counts micro-batches, so it can
                reach ``accumulate_grad_batches * trainer.global_step``.

        Raises:
            ValueError: If ``log_every_n_steps`` is smaller than 1.
        """
        super().__init__()
        if log_every_n_steps < 1:
            raise ValueError(
                f"log_every_n_steps must be >= 1, got {log_every_n_steps}"
            )
        self.log_every_n_steps = log_every_n_steps
        # Set in `on_train_batch_start`; tells `on_train_batch_end` whether the
        # current batch is a logging step.
        self.log_step_flag = False
        # Quantile levels to compute. Created on torch's default device, and
        # sampled activations are moved to that device before `torch.quantile`.
        self.qt_vals = torch.asarray(
            [0.0, 0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99, 1],
            dtype=torch.float,
        )
        # Metric-name suffixes, one per quantile level (e.g. "QT_0.05").
        self.qt_names = [f"QT_{val:.2f}" for val in self.qt_vals]
        self._handles: t.List[RemovableHandle] = []
        # Maps "quantile/<module name>.<QT name>" to a float for the current step.
        self.metrics: t.Dict[str, t.Any] = {}

    def _record_quantiles(self, name: str, value: t.Any) -> None:
        """Compute quantiles of ``value`` and store them in ``self.metrics``.

        Non-tensor values are skipped with a debug log rather than raising
        inside a forward hook, which would abort the training step.
        """
        if not isinstance(value, torch.Tensor):
            logger.debug(
                "Skipping quantiles for %s: expected a tensor, got %s",
                name,
                type(value).__name__,
            )
            return
        # Detach so the sampling/quantile ops are not recorded by autograd.
        y = sampling_elements_from_tensor(
            value.detach(), self.max_elem_number, 0
        )
        y = y.to(device=self.qt_vals.device, dtype=torch.float)
        qt_results = torch.quantile(y, self.qt_vals).tolist()
        for qt_name, qt_value in zip(self.qt_names, qt_results):
            self.metrics[f"quantile/{name}.{qt_name}"] = qt_value

    def _get_lmhead_hook_fn(self, name: str) -> t.Callable:
        """Return a forward hook that records quantiles of the module's output."""

        def _hook(
            mod: torch.nn.Module,
            inp: t.Tuple[t.Any, ...],
            oup: t.Any,
        ) -> None:
            self._record_quantiles(name, oup)

        return _hook

    def _get_mha_hook_fn(self, name: str) -> t.Callable:
        """Return a forward hook that records quantiles of the module's first input."""

        def _hook(
            mod: torch.nn.Module,
            inp: t.Tuple[t.Any, ...],
            oup: t.Any,
        ) -> None:
            # Forward hooks receive only positional arguments, so a module
            # called purely with keyword arguments yields an empty tuple here.
            if not inp:
                logger.debug("Skipping quantiles for %s: no positional input", name)
                return
            self._record_quantiles(name, inp[0])

        return _hook

    def _remove_handles(self) -> None:
        """Remove every registered forward hook and forget its handle."""
        for handle in self._handles:
            handle.remove()
        # Clear the list as well, so stale handles do not accumulate across
        # logging steps.
        self._handles.clear()

    @rank_zero_only
    def on_train_batch_start(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        batch: t.Any,
        batch_idx: int,
    ) -> None:
        """Register quantile hooks if this batch is a logging step."""
        self.log_step_flag = (
            not trainer.fit_loop._should_accumulate()
            and trainer.global_step % self.log_every_n_steps == 0
        )
        if not self.log_step_flag:
            return
        # Drop any hooks left over from an earlier step before adding new ones.
        self._remove_handles()
        self.metrics = {}
        for name, module in pl_module.named_modules():
            if name == "model.linear":
                hook_fn = self._get_lmhead_hook_fn(name)
            elif name.endswith("self_attn"):
                hook_fn = self._get_mha_hook_fn(name)
            else:
                continue
            self._handles.append(module.register_forward_hook(hook_fn))

    @rank_zero_only
    def on_train_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: STEP_OUTPUT,
        batch: t.Any,
        batch_idx: int,
    ) -> None:
        """Send the collected quantiles to the logger and tear down the hooks."""
        if not self.log_step_flag:
            return
        try:
            # `step` is intentionally not passed: per the original author,
            # `log_metrics` adds the global step itself, and passing `step`
            # caused confusion with the actual global_step.
            if trainer.logger is not None:
                trainer.logger.log_metrics(self.metrics)
        finally:
            # Always detach the hooks, even if logging fails.
            self._remove_handles()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment