# This file is largely inspired by and mostly follows the structure of
# ``fairscale.nn.FullyShardedDataParallel`` in
from collections import OrderedDict
import contextlib
from enum import Enum, auto
import functools
import gc
from itertools import chain
import logging
from math import inf
import time
import traceback
from typing import (
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.nn.utils.rnn import PackedSequence
import torch_xla.core.xla_model as xm
from .xla_flatten_params_wrapper import XlaFlattenParamsWrapper
from .utils import dummy_all_gather, dummy_all_reduce, dummy_reduce_scatter
FLOAT_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
class TrainingState(Enum):
Simple enum to indicate what state FSDP is in. Used for asserting
to make sure APIs are called in the correct state.
BACKWARD_PRE and BACKWARD_POST states are used to ensure we
receive backward hooks in the correct order. It is used to catch
unexpected order of hooks being called (likely due to our
hook registration logic or autograd engine logic changes).
IDLE = auto()
FORWARD = auto()
class XlaFullyShardedDataParallel(nn.Module):
A wrapper for sharding Module parameters across data parallel workers in
PyTorch XLA. XlaFullyShardedDataParallel is commonly shortened to FSDP.
The implementation of this class is largely inspired by and mostly follows
the structure of ``fairscale.nn.FullyShardedDataParallel`` in
Pseudo-code usage::
my_module =
sharded_module = XlaFullyShardedDataParallel(my_module)
optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
output = sharded_module(x, y)
loss = output.sum()
It is also possible to shard individual layers separately and have an outer
wrapper handle any leftover parameters. This can be helpful to further
reduce XLA device memory usage and CPU memory usage when initializing large
models and to improve training speed by overlapping the all-gather step
across the forward pass.
.. warning::
The module should be moved to XLA device *before* wrapping it with
FSDP. For nested FSDP, the inner FSDP modules also need to be on XLA
device before wrapping.
.. warning::
The optimizer must be initialized *after* the module has been wrapped,
since FSDP will shard parameters in-place and this will break any
previously initialized optimizers.
.. warning::
Please use ``optim.step()`` instead of ``xm.optimizer_step(optim)`` for
optimizer update. The latter averages the gradients across XLA devices,
which is incorrect for FSDP.
.. warning::
When saving checkpoints, the training process on each XLA device needs
to save its own (sharded) model and optimizer state_dict to a different
path. *To consolidate sharded checkpoints later, please also save
``model.get_shard_metadata()``* along with ``model.state_dict()`` and
``optimizer.state_dict()`` as follows:
ckpt = {
'model': model.state_dict(),
'shard_metadata': model.get_shard_metadata(),
'optimizer': optimizer.state_dict(),
ckpt_path = f'/tmp/rank-{xm.get_ordinal()}-of-{xm.xrt_world_size()}.pth', ckpt_path, master_only=False)
When resuming training of an FSDP model from saved checkpoints, all
training processes need to load their corresponding (sharded) model and
optimizer state_dict. Use ``consolidate_sharded_model_checkpoints`` or
run ``python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts``
to build a full model state_dict for the original unwrapped module from
the sharded model state_dict.
module (nn.Module):
module to be wrapped with FSDP.
reshard_after_forward (bool, Optional):
if ``True``, reshard parameters after the forward pass. This saves
memory but slows training. This is only relevant when resharding
individual layers.
flatten_parameters (bool, Optional):
if ``True``, flatten parameters into a single contiguous tensor for
all_gather and reduce_scatter, which could potentially improve speed.
In this case, one cannot apply separate optimizer groups to different
original parameters in the wrapped module (e.g. setting bias terms or
any BatchNorm submodules to have zero weight decay) since all the
original parameters now become a single concatenated vector.
execute_sharding_on_init (bool, Optional):
if ``True``, immediately execute the parameter sharding via
`xm.mark_step` to free up the memory of the full parameters.
optimization_barrier_in_forward (bool, Optional):
if ``True``, apply `xm.optimization_barrier_` on the FSDP module's
inputs and outputs. This avoids XLA fusion with other forward pass
computation outside the FSDP module and could save additional memory.
optimization_barrier_in_backward (bool, Optional):
if ``True``, apply `xm.optimization_barrier_` on the FSDP module's
backward incoming gradients. This avoids XLA fusion with other
backward pass computation outside the FSDP module and could save
additional memory.
mark_step_on_finalization (bool, Optional):
if ``True``, call `xm.mark_step` upon finalizing gradients in the
root FSDP module. Here `xm.mark_step` is only called once for the
entire backward pass and should therefore only moderately increase
the execution time. When setting to ``True``, this option may help
prevent undesired fusion in backward pass and save more memory.
disable_reshard_on_root (bool, Optional):
If ``True``, ``reshard_after_forward`` will be set to ``False`` if
the module is a FSDP root module to improve performance. For some
cases, we do not reshard the full parameters of an FSDP root module
since those parameters are needed immediately for the backward pass.
If ``False``, the performance will be lower, but it is needed because
it helps to save memory. Consider a case that an FSDP root module is
a submodule of a model. Backward pass may not start immediately after
the FSDP root module finishes its forward. So, resharding the parameters
for the FSDP root modules can help to save memory in this case.
Default: True.
compute_dtype (torch.dtype, Optional):
dtype for full parameters for computation. This defaults to
``torch.float32`` but can be set to ``torch.float16`` or
``torch.bfloat16``. The sharded parameters will always be in FP32.
buffer_dtype (torch.dtype, Optional):
dtype for buffers for computation. This defaults to ``compute_dtype``.
fp32_reduce_scatter (bool, Optional):
if ``True``, then reduce-scatter gradients in FP32. This is only
relevant when *``compute_dtype``* is not ``torch.float32``.
sharding_groups (list, Optional):
If specified, FSDP will use this ``sharding_groups`` for all-gather
and reduce-scatter ops in full parameter construction and gradient
sharding. This can be useful for mixing FSDP with model parallelism
such as Megatron. One must also specify ``sharding_rank`` and
``sharding_world_size`` when using ``sharding_groups``.
sharding_rank (int, Optional):
The rank of this sharding instance. This must be specified if
``sharding_groups`` is provided. Otherwise it defaults to
sharding_world_size (int, Optional):
The world_size of this sharding instance. This must be specified if
``sharding_groups`` is provided. Otherwise it defaults to
def __init__(
module: nn.Module,
reshard_after_forward: bool = True,
flatten_parameters: bool = False,
execute_sharding_on_init: bool = True,
optimization_barrier_in_forward: bool = True,
optimization_barrier_in_backward: bool = True,
mark_step_on_finalization: bool = False,
disable_reshard_on_root: bool = True,
compute_dtype: Optional[torch.dtype] = None,
buffer_dtype: Optional[torch.dtype] = None,
fp32_reduce_scatter: bool = False,
sharding_groups: Optional[List[List[int]]] = None,
sharding_rank: Optional[int] = None,
sharding_world_size: Optional[int] = None,
on_neuron: bool = True,
_shard_size_multiple: int = 128,
_pin_layout_in_all_reduce: bool = False,
_pin_layout_in_all_gather: bool = False,
_pin_layout_in_reduce_scatter: bool = False,
_debug_dummy_forward_pass: bool = False,
_debug_msg: str = "xla_fsdp",
_debug_print: bool = False,
_debug_dummy_all_gather_op: bool = False,
_debug_dummy_all_reduce_op: bool = False,
_debug_dummy_reduce_scatter_op: bool = False,
_debug_dummy_optimization_barrier_op: bool = False,
if isinstance(module, XlaFullyShardedDataParallel):
raise RuntimeError(
"Cannot wrap a module that is already wrapped with FSDP. For nested FSDP, "
"first wrap the inner child modules before wrapping the outer parent module."
is_forward_defined = (
hasattr(module, "forward") and hasattr(module.forward, "__func__") and
module.forward.__func__ != torch.nn.Module.forward)
if not is_forward_defined:
raise RuntimeError(
"The module wrapped by FSDP *must define a `forward` method and call it "
"during the module's forward pass for FSDP to work correctly.* "
"Hence, do not wrap `nn.ModuleList` or `nn.ModuleDict` with FSDP "
"(since they don't have `forward` defined), "
"and do not perform the forward pass in other ways apart from the `forward` method. "
"(i.e. you should directly call the FSDP-wrapped module itself in your code, "
"instead of using any of its submodules or its weights).")
self.reshard_after_forward = self._orig_reshard_after_forward = reshard_after_forward
self.disable_reshard_on_root = disable_reshard_on_root
self.flatten_parameters = flatten_parameters
self.optimization_barrier_in_forward = optimization_barrier_in_forward
self.optimization_barrier_in_backward = optimization_barrier_in_backward
self.mark_step_on_finalization = mark_step_on_finalization
if compute_dtype is not None and compute_dtype not in FLOAT_DTYPES:
raise ValueError(
f"compute_dtype must be one of {FLOAT_DTYPES}, not {compute_dtype}")
self.compute_dtype = compute_dtype or torch.float32
if buffer_dtype is not None and buffer_dtype not in FLOAT_DTYPES:
raise ValueError(
f"buffer_dtype must be one of {FLOAT_DTYPES}, not {buffer_dtype}")
self.buffer_dtype = buffer_dtype or self.compute_dtype
self.fp32_reduce_scatter = fp32_reduce_scatter
# Make sharded parameter sizes a multiple of 128 for efficient all_gather ops on TPUs
# (see for details)
# TODO (ronghanghu): change the default to 1 after is resolved
self._shard_size_multiple = _shard_size_multiple
# Set layout pinning to False in all_gather, all_reduce, and reduce_scatter so that they can work together
# TODO (ronghanghu): change the default layout pinning to True after it's supported simultaneously
# on all collective ops (see for details)
if _debug_dummy_all_gather_op:
self.all_gather_op = dummy_all_gather
self.all_gather_op = functools.partial(
xm.all_gather, pin_layout=_pin_layout_in_all_gather)
if _debug_dummy_all_reduce_op:
self.all_reduce_op = dummy_all_reduce
self.all_reduce_op = functools.partial(
xm.all_reduce, pin_layout=_pin_layout_in_all_reduce)
if _debug_dummy_reduce_scatter_op:
self.reduce_scatter_op = dummy_reduce_scatter
self.reduce_scatter_op = functools.partial(
xm.reduce_scatter, pin_layout=_pin_layout_in_reduce_scatter)
if _debug_dummy_optimization_barrier_op:
self.optimization_barrier_op = lambda *args: None
self.optimization_barrier_op = xm.optimization_barrier_
# Allow specifying groups for the sharding collective ops, useful for mixing
# FSDP data parallelism with model parallelism (e.g. Megatron)
self.sharding_groups = sharding_groups
if sharding_groups is None:
self.rank = xm.get_ordinal()
self.world_size = xm.xrt_world_size()
if sharding_rank is None or sharding_world_size is None:
raise ValueError(
"sharding_rank and sharding_world_size must be provided when sharding_groups is specified"
self.rank = sharding_rank
self.world_size = sharding_world_size
self.on_neuron = on_neuron
# Options for debugging
# - set _debug_dummy_forward_pass=True to check for parameter-only memory consumption
# - set _debug_msg="xxx" and _debug_print=True to distinguish different FSDP instance
self._debug_dummy_forward_pass = _debug_dummy_forward_pass
self._debug_msg = _debug_msg
self._debug_print = _debug_print
self.gradient_predivide_factor: float = self._get_gradient_predivide_factor(
self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor
self._tstart = time.time()
# Only handle params which are not already sharded. This enables
# sharding individual layers of a Module, with an outer wrapper to
# shard any leftover parameters.
params = []
for param in module.parameters():
if not hasattr(param, "_is_sharded"):
# For now, it is either all flatten or none flatten.
if self.flatten_parameters:
# separately flatten trainable and frozen parameters
trainable_params = [p for p in params if p.requires_grad]
frozen_params = [p for p in params if not p.requires_grad]
to_be_flatten_params: List[List[Parameter]] = [trainable_params]
if len(frozen_params) > 0:
non_flatten_params = []
to_be_flatten_params: List[List[Parameter]] = [[]]
non_flatten_params = params
# Here, we don't automatically unflatten XlaFlattenParamsWrapper's state dict
# to avoid overhead on XLA devices. Use ``get_shard_metadata`` to save parameter info
# and ``consolidate_sharded_model_checkpoints`` to consolidate the sharded checkpoints.
self._fsdp_wrapped_module: nn.Module = XlaFlattenParamsWrapper(
del module # free original module in case it helps garbage collection
# Now, in this FSDP wrapper class, we keep a list of to-be-flatten and not-to-be-flatten
# params for doing sharding, gradient hooks, etc. Note, the ordering of the
# list matters: flatten params are always in the front.
params_to_shard = cast(
self._fsdp_wrapped_module.flat_params) + non_flatten_params
# Shard module parameters in place
# Cast the module buffers to the specified buffer_dtype
# Make sure all parameters are sharded.
for n, p in self.named_parameters():
assert hasattr(
p, "_is_sharded"), f"found unsharded parameter: {n} ; {p.size()}"
# Flag to indicate if we require gradient reduction in the backward
# pass. This will be False when inside the no_sync context manager.
self._require_backward_grad_sync: bool = True
# Enum to indicate if we're in the forward/backward pass, idle, etc.
self.training_state = TrainingState.IDLE
# Flag to indicate if the full params are gathered.
self.has_full_params: bool = False
# Flag to guard against preparing gradients multiple times per iteration.
# This is reset at the end of the backward pass.
self._pre_backward_hook_has_run = False
if execute_sharding_on_init:
# Execute the parameter sharding immediately and free up the memory
if not self.on_neuron:
def _get_gradient_predivide_factor(self, world_size: int) -> float:
factor: int = 1
while world_size % factor == 0 and world_size / factor > factor:
factor *= 2
return float(factor)
def set_gradient_divide_factors(self, pre: float, post: float,
                                recursive: bool) -> None:
    """
    Allow the user to override the pre and post divide factors.

    Args:
        pre (float): divide factor before the reduction.
        post (float): divide factor after the reduction.
        recursive (bool): recursively set it for all child FSDP instances or not.
    """
    if recursive:
        for module in self.modules():
            # Skip ``self`` (it is updated below) using an identity check, the
            # same way ``_cast_buffers`` tells the instance apart from its
            # children. Children are called non-recursively because
            # ``self.modules()`` already walks every nested module.
            if module is not self and isinstance(module,
                                                 XlaFullyShardedDataParallel):
                module.set_gradient_divide_factors(pre, post, False)
    self.gradient_predivide_factor = pre
    self.gradient_postdivide_factor = post
def module(self) -> XlaFlattenParamsWrapper:
    """Expose the wrapped module so that ``model.module`` works as it does with DDP."""
    wrapped = self._fsdp_wrapped_module
    # The wrapper is always expected to be the flatten-params wrapper.
    assert isinstance(wrapped, XlaFlattenParamsWrapper)
    return wrapped
def params_with_grad(self) -> List[Parameter]:
    """Collect every parameter of this module whose ``.grad`` is currently set."""
    selected = []
    for param in self.parameters():
        if param.grad is None:
            continue
        selected.append(param)
    return selected
def clip_grad_norm_(
max_norm: Union[float, int],
norm_type: Union[float, int] = 2.0,
groups: Optional[List[List[int]]] = None,
) -> torch.Tensor:
Clip all gradients at this point in time. The norm is computed over all
gradients together, as if they were concatenated into a single vector.
Gradients are modified in-place.
max_norm (float or int): max norm of the gradients
norm_type (float or int): type of the used p-norm. Can be ``'inf'``
for infinity norm.
groups (list, optional): A list of lists, representing the replica
groups for the all-reduce operation to compute global norms.
See `xm.all_reduce` for details.
Total norm of the parameters (viewed as a single vector).
.. note:: This is analogous to `torch.nn.utils.clip_grad_norm_` but
handles the partitioning and multiple devices per rank under the
hood. The default torch util is not applicable here, because each
rank only has a partial view of all the grads in the model, so
calling it in the OSS context would lead to different scaling being
applied per subset of model parameters.
.. warning:: This needs to be called on all ranks, since synchronization
primitives will be used.
assert self._is_root, "clip_grad_norm should only be called on the root (parent) instance"
max_norm = float(max_norm)
norm_type = float(norm_type)
params_with_grad = self.params_with_grad
# Computes the max norm for this shard's gradients and sync's across workers
local_norm = _calc_grad_norm(params_with_grad, norm_type)
if norm_type == inf:
total_norm = self.all_reduce_op(xm.REDUCE_MAX, local_norm, groups=groups)
total_norm = self.all_reduce_op(
xm.REDUCE_SUM, local_norm**norm_type, groups=groups)
total_norm = total_norm**(1.0 / norm_type)
# Now multiply each grad by (max_norm/total_norm), same as torch 1.7
clip_coef = torch.clip(max_norm / (total_norm + 1e-6), 0.0, 1.0)
for p in params_with_grad:
return total_norm
def _shard_parameters_(self, params_to_shard) -> None:
At initialization we wrap a module with full parameters and shard the
parameters in-place. Sharding is implemented by viewing each parameter
as a 1D Tensor and retaining only a single slice, where the slice size
is determined by the number of data parallel workers.
Wrapping modules with many small parameters (or with a very large data
parallel world size) will result in many small parameter shards and slow
performance. In this case it's better to set *``flatten_parameters``* to
``True``, so that all of the small parameters in the module are combined
into a single contiguous Tensor and sharded once.
After this initial sharding is complete, the user can initialize a
``torch.optim.Optimizer`` in the usual way, i.e.::
.. code-block:: python
optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
The optimizer will see only a single slice of parameters and will thus
allocate less memory for optimizer state, avoiding redundancy across
data parallel workers.
Note: this method is implemented in a different manner from
``fairscale.nn.FullyShardedDataParallel``. Here we delete the original
module parameters and create new sharded parameter tensors (instead of
making sharded tensors an attribute of the original parameters). This
makes it easier to handle things (e.g. freeing parameters) on XLA.
if len(params_to_shard) > 0:
# When freeing the full parameters, we point their `.data` to this placeholder
# (so that the XLA compiler can reuse the memory storage).
self._dummy_data_placeholder = torch.zeros(
1, dtype=self.compute_dtype, device=params_to_shard[0].device)
# get the module names of each full parameter to shard
params_to_shard_set = set(params_to_shard)
assert len(params_to_shard_set) == len(params_to_shard), \
"params_to_shard should not have dups"
full_param_infos = []
shared_full_param_memo = {}
shared_full_param_infos = []
full_params = []
for module_name, m in self.named_modules():
for n, p in m.named_parameters(recurse=False):
if "xla" not in str(p.device):
raise ValueError(
"please moved the module to XLA device before wrapping with FSDP")
if p.dtype != torch.float32:
raise TypeError("only fp32 parameters are supported")
if p in params_to_shard_set:
if p in shared_full_param_memo:
mname, shared_m, shared_n = shared_full_param_memo[p]
(module_name, mname, m, n, shared_m, shared_n))
shared_full_param_memo[p] = (module_name, m, n)
full_param_infos.append((module_name, m, n))
assert len(full_params) == len(params_to_shard_set), \
f"there are parameters in params_to_shard not belonging to this module."
del shared_full_param_memo
self.full_params = full_params
self.full_param_infos = full_param_infos
self.shared_full_param_infos = shared_full_param_infos
# deregister the full parameter tensors from their modules (so that they won't
# appear in the FSDP model's `parameters()` or `named_parameters()` outputs;
# only the sharded parameters should appear in the FSDP model's `parameters()`)
for _, m, n in self.full_param_infos:
assert n in m._parameters
p = m._parameters.pop(n)
object.__setattr__(m, n, p)
for _, _, m, n, shared_m, shared_n in self.shared_full_param_infos:
assert n in m._parameters
p = m._parameters.pop(n)
object.__setattr__(m, n, p)
# allocate and register new sharded parameters
self.sharded_params = []
for p, (module_name, _, n) in zip(self.full_params, self.full_param_infos):
assert not hasattr(p, "_is_sharded")
shard_data = self._get_shard(
p_shard = nn.Parameter(shard_data, requires_grad=p.requires_grad)
p_shard._is_sharded = True
p_shard._orig_size =
p_shard._orig_name = f"{module_name}.{n}"
p_shard._name = f"_fsdp_shard.{p_shard._orig_name}".replace(
self.register_parameter(p_shard._name, p_shard)
p._sharded_param = p_shard # add a handle to the sharded parameter
# Free the full parameter storage (here we free its `.data`) but keep the tensor itself
# for auto-grad tracing (like `torch.autograd.Variable` before the tensor-variable merge). = self._dummy_data_placeholder
p._has_full_param = False
assert len(self.sharded_params) == len(self.full_params)
def _get_shard(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, int]:
"""Return the local shard of a full tensor."""
tensor = self._flatten_and_pad_to_world_size(
tensor, self.world_size * self._shard_size_multiple)
if self.on_neuron:
local_num_row = tensor.shape[0] // self.world_size
begin, end = self.rank * local_num_row, (self.rank + 1) * local_num_row
tensor = tensor[begin:end, ...].clone()
local_numel = tensor.numel() // self.world_size
begin, end = self.rank * local_numel, (self.rank + 1) * local_numel
tensor = tensor[begin:end].clone()
return tensor
def _cast_buffers(self,
dtype: Optional[torch.dtype] = None,
memo: Optional[Set] = None) -> None:
"""Move all buffers to the given *dtype*.
If *dtype* is not given, then it will default to ``self.buffer_dtype``.
In the case of nested FSDP instances, we will respect the child instance's
``buffer_dtype`` configuration.
dtype (torch.dtype, Optional):
dtype to cast buffers to (defaults to buffer_dtype)
memo (Set, Optional):
set of modules that have already been processed
if memo is None:
memo = set()
for module in self.modules():
if module is not self and isinstance(module, XlaFullyShardedDataParallel):
# Allow any child FSDP instances to handle their own buffers.
module._cast_buffers(dtype=dtype, memo=memo)
elif module not in memo:
for name, buf in module.named_buffers(recurse=False):
if buf is None:
if torch.is_floating_point(buf):
orig_dtype = buf.dtype
cast_dtype = dtype or self.buffer_dtype
if orig_dtype != cast_dtype:
buf =
buf._orig_dtype = orig_dtype
setattr(module, name, buf)
def extra_repr(self) -> str:
repr = (f"world_size={self.world_size}, "
f"rank={self.rank}, "
f"compute_dtype={self.compute_dtype}, "
f"buffer_dtype={self.buffer_dtype}, "
f"flatten_parameters={self.flatten_parameters}, "
f"reshard_after_forward={self.reshard_after_forward}, "
return repr
def __getattr__(self, name: str) -> Any:
"""Forward missing attributes to wrapped module."""
return super().__getattr__(name) # defer to nn.Module's logic
except AttributeError:
return getattr(self.module, name)
def __getitem__(self, key: int) -> Any:
"""Forward indexing calls in case the module is a nn.Sequential."""
return self.module.__getitem__(key)
def no_sync(self) -> Generator:
A context manager to disable gradient synchronizations across FSDP
processes. Within this context, gradients will be accumulated on module
variables, which will later be synchronized in the first
forward-backward pass after exiting the context.
.. note:: This likely results in higher memory usage because FSDP will
accumulate the full model gradients (instead of gradient shards)
until the eventual sync.
.. note:: Gradient accumulation can be done without this context,
avoiding the extra XLA device memory overhead, but with the extra
networking overhead.
assert self._is_root, "no_sync on inner FSDP is not supported"
# This instance may wrap other FSDP instances and we
# need to set all of them to accumulate gradients.
old_flags = []
for m in self.modules(): # includes self
if isinstance(m, XlaFullyShardedDataParallel):
old_flags.append((m, m._require_backward_grad_sync))
m._require_backward_grad_sync = False
for m, old_flag in old_flags:
assert m._require_backward_grad_sync is False
m._require_backward_grad_sync = old_flag
def _reset_lazy_init(self) -> None:
"""Reset instance so :func:`_lazy_init` will run on the next forward."""
self._is_root: Optional[bool] = None
self._all_sharded_params: Optional[Parameter] = None
self._output_pre_backward_hook_registered: Optional[Set] = None
self._backward_opt_barrier_tensors: Optional[List] = None
self._backward_opt_barrier_tensor_ids: Optional[Set] = None
self.reshard_after_forward = self._orig_reshard_after_forward
def _lazy_init(self) -> None:
Initialization steps that should happen lazily, typically right
before the first forward pass.
# Initialize _is_root and setup streams. These steps would ideally
# happen in __init__, but _is_root can only be determined after the
# entire model hierarchy is setup, thus we run it lazily.
if self._is_root is None:
if self._is_root and self.disable_reshard_on_root:
# Don't free the full params for the outer-most (root) instance,
# since those params will be needed immediately after for the
# backward pass.
self.reshard_after_forward = False
def _set_is_root(self) -> None:
If ``True``, implies that no other :class:`XlaFullyShardedDataParallel`
instance wraps this one. Called once by :func:`_lazy_init`.
Also sets self.children_share_process_group = True if all child
instances share the same process group. If some child instances use a
different process group, self.clip_grad_norm_ will raise an error.
if self._is_root is not None:
# No FSDP instance wraps this, else _is_root would be set to False.
self._is_root = True
self._all_sharded_params = list(self.parameters())
if self._debug_print:
f"root FSDP got {len(self._all_sharded_params)} total params (_debug_msg: {self._debug_msg}).",
# If the final backward callback has never been queued, state should be IDLE.
# If final backward callback is queued, the callback should be finished
# and the state was reset to be IDLE.
# This should be asserted at the beginning of forward pass in the root instance only.
# For children instances, if they are checkpointed, state will not be reset to
# IDLE after each inner forward/backward.
# As the root, we now set all children instances to False and
# give them a closure to try to queue a wait_for_post_backward.
for n, m in self.named_modules():
# `n != ""` excludes self.
if n != "" and isinstance(m, XlaFullyShardedDataParallel):
# We relax the assert for non-root instance, when the nested initialized module is wrapped
# again in FSDP later, for example after training to run inference.
assert m._is_root is None or not m._is_root
if m._is_root is None:
m._is_root = False
def _setup_output_hook_and_backward_opt_barrier_lists(self) -> None:
Set up a list to avoid registering pre-backward hooks incorrectly.
And a list to apply optimization barrier on backward pass tensors.
assert self._is_root, "This should only be called on the root"
self._output_pre_backward_hook_registered = set()
self._backward_opt_barrier_tensors = []
self._backward_opt_barrier_tensor_ids = set()
for n, m in self.named_modules():
if n != "" and isinstance(m, XlaFullyShardedDataParallel):
m._output_pre_backward_hook_registered = self._output_pre_backward_hook_registered
m._backward_opt_barrier_tensors = self._backward_opt_barrier_tensors
m._backward_opt_barrier_tensor_ids = self._backward_opt_barrier_tensor_ids
def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
# Start of a forward pass.
self.training_state = TrainingState.FORWARD
if self.compute_dtype != torch.float32:
# Cast the input float tensors to the specified compute_dtype
args, kwargs = _cast_floats_tensors(self.compute_dtype, *args, **kwargs)
# All-gather full parameters.
input_opt_barrier_tensors = []
if self.optimization_barrier_in_forward:
# Ensure that previous ops to build this module's inputs (which are
# usually performed in previous modules) are finished before rebuilding
# the full params of this FSDP module.
input_opt_barrier_tensors = collect_tensors((args, kwargs))
# Register backward hooks to reshard params and reduce-scatter grads.
# These need to be re-registered every forward pass.
if not self._debug_dummy_forward_pass:
outputs = self.module(*args, **kwargs)
# Run a dummy forward pass by summing the inputs and full parameter.
# This can be used to debug FSDP parameter memory consumption.
outputs = self._dummy_forward(*args, **kwargs)
if self.reshard_after_forward:
output_opt_barrier_tensors = []
if self.optimization_barrier_in_forward:
# Ensure that the full parameters of this FSDP module are freed
# before any new ops based on this module's outputs (which are usually
# performed in subsequent modules) can happen.
output_opt_barrier_tensors = collect_tensors(outputs)
# Register pre-backward hooks to all-gather the params for the backward
# pass (if output's grad was needed). This won't register anything if
# we are in eval mode.
# Some models run the forward pass multiple times, so we need to register the
# pre-backward hook on every output since the last output's hook has to
# fire first to setup for backward. However, we use ``self._pre_backward_hook_has_run``
# to prevent repeated overhead from multiple hook callbacks.
outputs = self._register_pre_backward_hooks(outputs)
if self.optimization_barrier_in_backward:
# Apply XLA compiler optimization barrier to FSDP outputs and their gradients to avoid
# fusion across FSDP modules (which sometimes results in higher memory consumption).
input_grad_opt_barrier_tensors = input_opt_barrier_tensors or collect_tensors(
(args, kwargs))
# Done with a forward pass.
self.training_state = TrainingState.IDLE
return outputs
def _dummy_forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
    """
    A dummy forward pass with minimal computation that sums all inputs and
    full parameters, e.g. to debug parameter memory consumption.

    Args:
        *args: forward inputs; only ``torch.float32`` tensors passed
            directly contribute to the sum.
        **kwargs: forward keyword inputs, handled like ``args``. The private
            key ``_xla_fsdp_dummy_forward_resursive`` (default ``True``) is
            removed from ``kwargs`` and controls whether the nested FSDP
            modules are run as well.

    Returns:
        The accumulated tensor; it starts as a one-element zero tensor on
        the XLA device.
    """
    outputs = torch.zeros(1, device=xm.xla_device())
    for t in chain(args, kwargs.values(), self.full_params):
        if isinstance(t, torch.Tensor) and t.dtype == torch.float32:
            outputs = outputs + t.mean()
    # recursively run dummy forward pass on inner FSDP modules (if any)
    # (the keyword spelling is part of the runtime interface and is kept as is)
    recursive = kwargs.pop("_xla_fsdp_dummy_forward_resursive", True)
    if recursive:
        assert self._is_root
        for m in self.modules():
            if m is not self and isinstance(m, XlaFullyShardedDataParallel):
                saved_flag = m._debug_dummy_forward_pass
                m._debug_dummy_forward_pass = True
                try:
                    outputs = m(outputs, _xla_fsdp_dummy_forward_resursive=False)
                finally:
                    # Always put the child's own debug setting back, even if
                    # its forward pass raised.
                    m._debug_dummy_forward_pass = saved_flag
    return outputs
def _try_adding_to_backward_opt_barrier_lists(self,
tensor: torch.Tensor) -> None:
Add tensor to backward pass optimization barrier list if it is not there.
if id(tensor) not in self._backward_opt_barrier_tensor_ids:
def _clear_backward_opt_barrier_lists(self) -> None:
"""Reset the backward pass optimization barrier list"""
def _register_grad_opt_barrier_hooks(
self, dependency_tensors: List[torch.Tensor]) -> None:
Register hook to `dependency_tensors` to put their gradient tensors into
self._backward_opt_barrier_tensors for backward pass optimization barrier.
if not torch.is_grad_enabled():
return # don't register hooks if grad isn't enabled
def _grad_opt_barrier_hook(t_grad: torch.Tensor):
return t_grad.view(t_grad.size()) # a view with barrier applied
for t in dependency_tensors:
if t.requires_grad:
# Walks `outputs` with `apply_to_tensors` and returns the result; when autograd
# is disabled the outputs are returned untouched.
# NOTE(review): this copy has lost its indentation, docstring delimiters and a
# number of statements (bodies of several `if`s, the call that owns the
# dangling state list, and the hook registration inside `_register_hook`).
# The visible gaps are marked below; confirm against upstream before editing.
def _register_pre_backward_hooks(self, outputs: Any) -> Any:
Register pre-backward hook to run before the wrapped module's
backward. Hooks should be attached to all outputs from the forward.
outputs: new outputs with hooks registered if they requires gradient.
if not torch.is_grad_enabled():
return outputs # don't register hooks if grad isn't enabled
if self._is_root:
# This actually means that only root instance has
# _post_backward_callback_queued defined. Accidentally accessing this field
# will assert on all other instances, giving us a nice bug checker.
self._post_backward_callback_queued = False
def _pre_backward_hook(t_grad: torch.Tensor) -> None:
# try to queue final backward callback only once for root, so
# that final backward callback is attached to the outer most
# backward graph task and called after all the backward
# calls are completed.
# NOTE(review): the bodies of the next two `if`s are missing in this copy.
if self._is_root:
if self.optimization_barrier_in_backward:
# All-gather full parameters or switching to the full params.
# Note, ``self._rebuild_full_params`` is idempotent. So in case it is called
# unnecessarily, it doesn't incur much overhead.
if self.reshard_after_forward:
dependency_tensors = []
if self.optimization_barrier_in_backward:
# Ensure that backward pass ops of feature gradients, parameter
# gradient and sharding, and full-param freeing (which are usually
# performed in previous modules and are registered to
# self._backward_opt_barrier_tensors in _grad_opt_barrier_hook,
# _pre_backward_hook, and _post_backward_hook) are finished before
# rebuilding the full params of this FSDP module.
dependency_tensors = self._backward_opt_barrier_tensors
# NOTE(review): the rebuild call(s) that would consume `dependency_tensors`
# are not present in this copy.
# Only run the following once per iteration (i.e. in case
# it is multiple outputs or multiple forward passes).
if not self._pre_backward_hook_has_run:
self._pre_backward_hook_has_run = True
# Start of a backward pass for the first time in an iteration.
self.assert_state([TrainingState.IDLE, TrainingState.BACKWARD_PRE])
# Check p.grad to make sure that it is in the right shape, device, etc.
for p, p_shard in zip(self.full_params, self.sharded_params):
if p.grad is not None:
assert p.grad.device == p_shard.device
assert p.grad.size() == p_shard._orig_size
# Transition to BACKWARD_PRE state if currently IDLE. We can transition from BACKWARD_POST
# to IDLE when FSDP is within activation checkpointing and called multiple times, due to the
# extra forward pass for re-computation.
if self.training_state == TrainingState.IDLE:
self.training_state = TrainingState.BACKWARD_PRE
# NOTE(review): the next line is the tail of a call whose opening line is
# missing (presumably `self.assert_state(`) -- TODO confirm.
[TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST])
if self.optimization_barrier_in_backward:
# NOTE(review): statement(s) between this `if` and the view below appear to
# be missing.
t_grad = t_grad.view(t_grad.size()) # a view with barrier applied
return t_grad
# Counts how many outputs have been selected for hook registration so far.
_registered = 0
def _register_hook(t: torch.Tensor) -> torch.Tensor:
# We don't register the pre_backward hook on the same tensor that has been
# returned from an inner FSDP, unless it is the first one.
nonlocal _registered
assert self._output_pre_backward_hook_registered is not None
if t.requires_grad and (_registered == 0 or id(t)
not in self._output_pre_backward_hook_registered):
# NOTE(review): the statements that register `_pre_backward_hook` on `t` and
# record id(t) are missing here; only the counter update remains.
_registered += 1
return t
# Attach hooks to Tensor outputs.
outputs = apply_to_tensors(_register_hook, outputs)
return outputs
def _register_post_backward_hooks(self) -> None:
Register backward hooks to reshard params and reduce-scatter grads.
This is called during forward pass. The goal is to attach a hook
on each of the parameter's gradient generating function (``grad_acc``
below) so that the hook is called *after* all gradients for that
param are computed.
1. We want the hook to fire once and only once *after* all gradients
are accumulated for a param.
2. If it fires more than once, we end up incorrectly shard the grad
multiple times. (could lead to dimension too small)
3. If it fires once but too early or doesn't fire, we leave gradients
unsharded. (could lead to dimension too large)
Empirically, keep the first hook register per forward pass seems to
work the best. We do need to remove the hook at the end of the
backward pass. Otherwise, the next forward pass will not register
a new hook, which is needed for a new forward pass.
if not torch.is_grad_enabled():
return # don't register grad hooks if grad isn't enabled
for p in self.full_params:
if p.requires_grad:
if hasattr(p, "_shard_bwd_hook"):
# Register a hook on the first call, empirically, autograd
# fires it at the end for this param, which makes sense.
p_tmp = p.expand_as(p) # Get a grad_fn on p_tmp.
assert p_tmp.grad_fn is not None
grad_acc = p_tmp.grad_fn.next_functions[0][
0] # Gets its GradAccumulation object.
handle = grad_acc.register_hook(
functools.partial(self._post_backward_hook, p))
p._shard_bwd_hook = (grad_acc, handle)
# Autograd hook invoked with the full parameter `param` bound as first argument
# (extra hook arguments are ignored). It moves the training state to
# BACKWARD_POST and, per its description, reduce-scatters the gradient and
# accumulates the local shard into `param._sharded_param.grad`.
# NOTE(review): this copy has lost its indentation, docstring delimiters and
# many statements/right-hand sides (several assignments end at `=`, the
# arguments of `self.reduce_scatter_op(` are gone, and most `if` bodies are
# empty). The most visible gaps are marked below; do not infer behaviour from
# this copy alone -- confirm against upstream.
def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
At the start of :func:`_post_backward_hook`, ``param.grad`` contains the
full gradient for the local batch. The reduce-scatter op will replace
``param.grad`` with a single shard of the summed gradient across all
XLA devices. This shard will align with the current rank. For example::
before reduce_scatter:
param.grad (rank #0): [1, 2, 3, 4]
param.grad (rank #1): [5, 6, 7, 8]
after reduce_scatter:
param.grad (rank #0): [6, 8] # 1+5, 2+6
param.grad (rank #1): [10, 12] # 3+7, 4+8
The local XLA device's ``optim.step`` is responsible for updating a single
shard of params, also corresponding to the current XLA device's rank. This
alignment is created by :func:`_shard_parameters_`, which ensures that
the local optimizer only sees the relevant parameter shard.
# First hook callback will see PRE state. If we have multiple params,
# then subsequent hook callbacks will see POST state.
self.assert_state([TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST])
self.training_state = TrainingState.BACKWARD_POST
# NOTE(review): as written the assert below can never pass inside this `if`;
# lines between the two are presumably missing -- TODO confirm.
if param.grad is None:
assert param.grad is not None, param.shape
if param.grad.requires_grad:
raise RuntimeError(
"FSDP only works with gradients that don't require gradients")
# NOTE(review): right-hand side missing.
grad =
if self._require_backward_grad_sync or self.reshard_after_forward:
# Free full params. As a special case, we don't free the full params
# when in a ``no_sync`` context (as inversely indicated by
# ``self._require_backward_grad_sync``), since the params will not
# get updated before the next forward. This saves networking
# bandwidth but uses more XLA device memory.
# NOTE(review): the free-params call and the bodies of the next two `if`s are
# missing.
if not self._require_backward_grad_sync:
if self.gradient_predivide_factor > 1:
# Average grad by world_size for consistency with PyTorch DDP.
# Shard the gradients with `reduce_scatter`.
# Clear grad on the tensor, so any repeated gradient computations do not interfere with this reduction.
param.grad = None
# The gradient is flattened and padded to a multiple of
# world_size * _shard_size_multiple before the reduce-scatter.
grad_flat = self._flatten_and_pad_to_world_size(
grad, self.world_size * self._shard_size_multiple)
if self.optimization_barrier_in_backward:
if grad_flat.dtype != torch.float32 and self.fp32_reduce_scatter:
# NOTE(review): right-hand side missing.
grad_flat =
# NOTE(review): call arguments missing.
reduced_grad = self.reduce_scatter_op(
if reduced_grad.dtype != torch.float32:
# NOTE(review): right-hand side missing.
reduced_grad =
if self.optimization_barrier_in_backward:
if self.gradient_postdivide_factor > 1:
# Average grad by world_size for consistency with PyTorch DDP.
grad._has_full_param = True
grad_flat._has_full_param = True
# NOTE(review): the next line is the middle of a call whose opening line is
# missing.
[grad, grad_flat],
# Accumulate into the gradient shard.
assert hasattr(param, "_sharded_param")
p_shard = param._sharded_param
if p_shard.grad is None:
# NOTE(review): right-hand side missing.
p_shard.grad =
assert p_shard.grad.shape == reduced_grad.shape
# NOTE(review): the line below looks like two statements fused together (a
# device assert and an in-place `+=` accumulation) -- TODO confirm.
assert p_shard.grad.device == reduced_grad.device +=
# Root-only: flips `_post_backward_callback_queued` to True the first time it
# is called so the final callback is queued at most once.
# NOTE(review): docstring delimiters are missing and the statement that
# actually queues the callback on the autograd engine is not present in this
# copy -- TODO restore from upstream.
def _queue_wait_for_post_backward(self) -> None:
Try to queue a `wait_for_post_backward` callback.
Only called on root and only queue one callback at the beginning of
outer most backward.
assert self._is_root
if not self._post_backward_callback_queued:
self._post_backward_callback_queued = True
# Wrapper meant to surface exceptions raised by `_wait_for_post_backward` with
# a hint about unused inner FSDP modules.
# NOTE(review): docstring delimiters, the `try:` block with the wrapped call,
# the call that emits the message below, and any re-raise are missing in this
# copy; whether the exception is re-raised cannot be determined from here.
def _try_wait_for_post_backward(self) -> None:
Catch and print any exception in `_wait_for_post_backward`. Otherwise the
exception is not printed and error is very confusing as shown below.
built-in method run_backward of torch._C._EngineBase object at 0x7f26dc335aa0> returned NULL without setting an error
except Exception as e:
f"Exception below occurred in post-backward (_debug_msg: {self._debug_msg}). "
f"This is often due to some inner FSDP modules not being used "
f"in an outer FSDP module's forward pass. Please make sure that all inner "
f"FSDP modules participate in the forward pass when using nested FSDP.\n"
f"{type(e).__name__}: {e}")
# Root-only end-of-backward cleanup: resets per-iteration flags and the
# training state on every nested FSDP module, and removes the
# `_shard_bwd_hook` attribute from parameters.
# NOTE(review): this copy has lost its indentation and many statements (state
# asserts, hook removal, `else:` branches, list closers, barrier / mark_step
# calls). The visible gaps are marked below; confirm against upstream.
def _wait_for_post_backward(self) -> None:
"""Wait for post-backward to finish. Only called on root instance."""
assert self._is_root
# Check if the root module has params and if any of them has
# the `requires_grad` field set. If `requires_grad=False` for
# all the params, the post_backward hook will not fire and the
# state will remain in `TrainingState.BACKWARD_PRE`.
# NOTE(review): the body of this `if` (and any `else:`) is missing.
if any([p.requires_grad for p in self.full_params]):
# A backward pass is done, clean up below.
def _finalize_parameters(fsdp_module: XlaFullyShardedDataParallel) -> None:
"""Helper used below on all fsdp modules."""
for p in fsdp_module.full_params:
# NOTE(review): body missing (presumably skips params without grad).
if not p.requires_grad:
if hasattr(p, "_shard_bwd_hook"):
# `_shard_bwd_hook` is the (grad_acc, handle) pair stored when the hook was
# registered. NOTE(review): the handle-removal call is not present here.
assert len(p._shard_bwd_hook) == 2, len(p._shard_bwd_hook)
delattr(p, "_shard_bwd_hook")
# Update root and nested FSDP's hooks and flags.
for m in self.modules(): # includes self
if isinstance(m, XlaFullyShardedDataParallel):
m._pre_backward_hook_has_run = False
# NOTE(review): bodies/`else:` of the next two `if`s are missing, so the
# nesting of the state assert further down cannot be determined from here.
if any(p.requires_grad for p in m.parameters()):
# Check if the module has params and if any of them has
# the `requires_grad` field set. If `requires_grad=False` for
# all the params, the post_backward hook will not fire and the
# state will remain in `TrainingState.BACKWARD_PRE`.
if any([p.requires_grad for p in m.full_params]):
# When `m` and its children has no params or has params but
# none with `requires_grad==True`, there are two cases:
# 1. output tensors are `requires_grad==True`. In this case,
# pre-backward hook is still registered, so it is in BACKWARD_PRE state.
# 2. output tensors are `requires_grad==False`. In this case,
# pre-backward hook is not registered, so it is in IDLE state.
m.assert_state([TrainingState.BACKWARD_PRE, TrainingState.IDLE])
m.training_state = TrainingState.IDLE
if m._is_root:
# reset this flag for cases like "one forward pass + multiple backward passes"
self._post_backward_callback_queued = False
# clear this list for next iteration
# NOTE(review): the clearing statement itself is missing after the assert.
assert self._output_pre_backward_hook_registered is not None
if self.optimization_barrier_in_backward:
# Ensure that backward pass ops of feature gradients, parameter
# gradient and sharding, and full-param freeing (which are usually
# performed in previous modules and are registered to
# self._backward_opt_barrier_tensors in _grad_opt_barrier_hook,
# _pre_backward_hook, and _post_backward_hook) are finished before
# accessing the sharded gradients of this FSDP module.
# NOTE(review): the list literal below is unterminated and the two
# comprehensions after it have lost their element expressions.
params_with_grad = [
p for p in self._all_sharded_params if p.grad is not None
params_data = [ for p in params_with_grad]
grad_data = [ for p in params_with_grad]
dependency_tensors = params_data + grad_data
# NOTE(review): the barrier call on `dependency_tensors` is missing, and the
# loop body below has been fused onto the `for` header.
for p, p_data, g_data in zip(params_with_grad, params_data,
grad_data): = p_data = g_data
if self.mark_step_on_finalization:
# Forcing an execution at the end of backward pass to avoid any XLA compiler
# fusion between backward and optimizer (e.g. AdamW and SGD) step.
# Here `xm.mark_step` is only called once for the entire backward pass and
# should therefore only moderately increase the execution time.
# It may help prevent undesired fusion in backward pass and save more memory.
# NOTE(review): the `xm.mark_step` call and the print call that wraps the
# f-string below are missing.
if self._debug_print:
f"mark_step called in FSDP _wait_for_post_backward (_debug_msg: {self._debug_msg})",
# All-gathers each sharded parameter back into its full parameter and marks
# both the parameter (`p._has_full_param`) and the module
# (`self.has_full_params`) as holding full values.
# NOTE(review): this copy has lost its indentation, docstring delimiters and
# several statements (the early `return`, barrier calls, `else:` between the
# Neuron and TPU paths, dtype cast, `p.data` assignments). The visible gaps are
# marked below; confirm against upstream.
def _rebuild_full_params(self,
dependency_tensors: Optional[List[
torch.Tensor]] = None,
apply_opt_barrier: bool = True) -> None:
Gather all shards of params. If `dependency_tensors` is provided,
it ensures that previous ops to compute tensors in `dependency_tensors`
are finished before rebuiding the full parameters.
Note, this is idempotent if full params are already gathered. Callers
assume the idempotency. So please keep it that way.
# NOTE(review): body missing; presumably an early return that provides the
# idempotency described above -- TODO confirm.
if self.has_full_params:
if dependency_tensors is None:
dependency_tensors = []
# NOTE(review): body missing.
if apply_opt_barrier:
for p, p_shard in zip(self.full_params, self.sharded_params):
if not p._has_full_param:
# Work on a detached shard so the gather is not tracked by autograd.
p_shard_data = p_shard.detach()
# NOTE(review): body missing.
if apply_opt_barrier:
if p_shard_data.dtype != self.compute_dtype:
# NOTE(review): right-hand side missing (presumably a cast to
# self.compute_dtype).
p_shard_data =
# gather full parameter from shards
if self.on_neuron:
p_padded = self.all_gather_op(
p_shard_data, groups=self.sharding_groups)
# NOTE(review): body missing.
if apply_opt_barrier:
# NOTE(review): the trailing "= p_padded" looks like the remains of a separate
# assignment fused onto the comment.
assert p_padded.shape == p_shard._orig_size # TODO: pad for world size = p_padded
# NOTE(review): an `else:` is presumably missing here, so that the 2-D gather
# below is the non-Neuron path -- TODO confirm.
# reshape sharded parameters to 2d tensors for efficient gathering on
# TPUs (see for details).
p_shard_2d = p_shard_data.view(-1, self._shard_size_multiple)
p_padded = self.all_gather_op(
p_shard_2d, groups=self.sharding_groups).flatten()
if apply_opt_barrier:
# NOTE(review): two statements appear fused: the barrier call and an
# assignment that trims the padding and restores the original shape.
self.optimization_barrier_op([p_padded]) = p_padded[:p_shard._orig_size.numel()].view(p_shard._orig_size)
p._has_full_param = True
self.has_full_params = True
# Marks the selected full parameters (all of them when `params` is None) as no
# longer holding full values and clears `self.has_full_params`.
# NOTE(review): this copy has lost its indentation, docstring delimiters and
# several statements (`else:`, list closer, the assignment that swaps in the
# placeholder data, the tail of the final call). The gaps are marked below.
def _free_full_params(self,
params: Optional[List[Parameter]] = None,
dependency_tensors: Optional[List[torch.Tensor]] = None,
apply_opt_barrier: bool = True) -> None:
Free up storage for full parameters. If `dependency_tensors` is provided,
it ensures that the full parameters are freed before any new ops that
depend on tensors in `dependency_tensors` can be executed.
if params is None:
full_params = self.full_params
sharded_params = self.sharded_params
# NOTE(review): an `else:` is presumably missing before the next two
# assignments; as written they would always override the defaults above.
full_params = params
# Only parameters that carry a `_sharded_param` attribute contribute a shard.
# NOTE(review): the closing bracket of this list is missing.
sharded_params = [
p._sharded_param for p in params if hasattr(p, "_sharded_param")
if dependency_tensors is None:
dependency_tensors = []
self.has_full_params = False
for p in full_params:
if p._has_full_param:
# NOTE(review): the assignment target has been fused into the comment below;
# presumably the parameter's data is replaced by
# self._dummy_data_placeholder -- TODO confirm.
# free the original full parameter = self._dummy_data_placeholder
p._has_full_param = False
if apply_opt_barrier:
# NOTE(review): the final argument and closing parenthesis are missing.
self._apply_opt_barrier_to_params_and_tensors(full_params, sharded_params,
# Passes the parameters, their shards and `dependency_tensors` through
# `self.optimization_barrier_op` in a single call.
# NOTE(review): this copy has lost its indentation, docstring delimiters (the
# description also stops mid-sentence), the early `return`, the element
# expressions of both comprehensions, the tail of the barrier call and the
# assignment targets in the two trailing loops. Confirm against upstream.
def _apply_opt_barrier_to_params_and_tensors(
self, p_list: List[torch.Tensor], p_shard_list: List[torch.Tensor],
dependency_tensors: List[torch.Tensor]):
Apply XLA compiler optimization barrier to full and shared parameters
and other dependency tensors. This is to avoid fusion of the full
parameter rebuilding and freeing with other computation.
Otherwise, the XLA compiler might fuse `_rebuild_full_params` and
`_free_full_params` in the forward pass with any of these calls in the
backward pass through common subexpression elimination (CSE) and keep the
full parameters (not freeing them and rebuilding them later, essentially
changing `reshard_after_forward` to `False` and using more memory).
This method also introduce control dependency on `dependency_tensors`, so
that all tensors in `dependency_tensors` must be evaluated before any new
computation on the full or sharded parameters or `dependency_tensors` can
# NOTE(review): body missing; presumably returns early when there is nothing
# to put behind the barrier -- TODO confirm.
if len(p_list) + len(p_shard_list) + len(dependency_tensors) == 0:
# NOTE(review): element expressions missing in both comprehensions.
p_data_list = [ for p in p_list]
p_shared_data_list = [ for p_shard in p_shard_list]
# NOTE(review): last operand and closing parenthesis missing.
self.optimization_barrier_op(p_data_list + p_shared_data_list +
# NOTE(review): the assignment targets of both loop bodies are missing and the
# bodies are fused onto the `for` headers.
for p, p_data in zip(p_list, p_data_list): = p_data
for p_shard, p_shard_data in zip(p_shard_list, p_shared_data_list): = p_shard_data
def assert_state(self, state: Union[TrainingState,
                                    List[TrainingState]]) -> None:
    """Raise ``ValueError`` unless ``self.training_state`` is an allowed state.

    Args:
        state: a single ``TrainingState`` or a list of acceptable states.

    Raises:
        ValueError: if the current training state is not among ``state``.
    """
    # Since assert can be turned off and this error checking
    # is really important, we use explicit error checking
    # and raise a ValueError if needed.
    if isinstance(state, TrainingState):
        state = [state]
    if self.training_state in state:
        return

    msg = f"expected to be in states {state} but current state " f"is {self.training_state}"
    # In case we are failing in the context of autograd hook, the raised
    # error may not surface a useful message, so rank 0 prints it as well.
    if self.rank == 0:
        print(f"Asserting FSDP instance is: {self}")
        print(f"ERROR: {msg}")
    raise ValueError(msg)
def get_original_names_and_sharded_parameters(self):
    """
    Get the sharded parameters along with their original names. Its output is
    similar to ``named_parameters`` but contains sharded (and flattened)
    parameters.

    Returns:
        A list of ``(name, sharded_param)`` tuples covering this module and
        every nested ``XlaFullyShardedDataParallel`` module. ``name`` is the
        module path joined with the shard's ``_orig_name``, with the wrapper
        segments ``_fsdp_wrapped_module.`` and ``_fpw_module.`` removed.
    """
    orig_named_parameters = []
    for module_name, m in self.named_modules():  # includes self
        if not isinstance(m, XlaFullyShardedDataParallel):
            continue
        prefix = module_name + "." if module_name != "" else ""
        # Each FSDP module owns its own shards: iterate ``m``'s shards (not
        # the root's) so nested modules are reported under their own prefix,
        # the same way ``get_shard_metadata`` walks ``m.sharded_params``.
        for p in m.sharded_params:
            n = prefix + p._orig_name
            n = n.replace("_fsdp_wrapped_module.", "").replace("_fpw_module.", "")
            orig_named_parameters.append((n, p))

    return orig_named_parameters
def get_shard_metadata(self):
    """
    Get the shard metadata to consolidate the sharded model checkpoints.
    The output from this method should be saved in a checkpoint file and
    used in ``consolidate_sharded_model_checkpoints``.

    Returns:
        A dict with the keys ``"shard_info"`` (per FSDP module: shard name ->
        original size and name), ``"flatten_info"`` (per flattened parameter:
        the wrapper's ``metadata(i)``), ``"buffer_info"`` (original dtype of
        buffers that carry ``_orig_dtype``), ``"world_size"`` and ``"rank"``.
    """
    shard_info = {}
    flatten_info = {}
    buffer_info = {}
    for module_name, m in self.named_modules():  # includes self
        # remove "_fpw_module." from module names since it is also removed in
        # XlaFullyShardedDataParallel's state_dict()
        module_name = module_name.replace("_fpw_module.", "")
        if isinstance(m, XlaFullyShardedDataParallel):
            shard_info[module_name] = {
                p_shard._name: {
                    "_orig_size": p_shard._orig_size,
                    "_orig_name": p_shard._orig_name,
                } for p_shard in m.sharded_params
            }

        if isinstance(m, XlaFlattenParamsWrapper):
            prefix = module_name + "." if module_name != "" else ""
            for i in range(len(m.flat_params)):
                flatten_info[f"{prefix}flat_param_{i}"] = m.metadata(i)

    # Buffers are collected once from the root, independent of the module walk.
    for name, buf in self.named_buffers():
        if buf is not None and hasattr(buf, "_orig_dtype"):
            buffer_info[name] = {"_orig_dtype": buf._orig_dtype}

    metadata = {
        "shard_info": shard_info,
        "flatten_info": flatten_info,
        "buffer_info": buffer_info,
        "world_size": self.world_size,
        "rank": self.rank,
    }
    return metadata
# Rank-0 debugging helper: optionally restarts the `_tstart` timer, then reads
# the XLA device memory info and converts the kB figures to GB.
# NOTE(review): indentation is lost and the call that should emit the f-string
# on the last line is missing in this copy, so as written nothing is printed --
# TODO restore from upstream.
def _print_r0(self, msg: str, restart: bool = False) -> None:
"""Debugging utility to print memory usage stats nicely on rank 0"""
if restart:
self._tstart = time.time()
if self.rank == 0:
memory_info = xm.get_memory_info(xm.xla_device())
# kB -> GB
gb_free = memory_info["kb_free"] / 1024 / 1024
gb_total = memory_info["kb_total"] / 1024 / 1024
f"{msg} free={gb_free: .4f} GB, total={gb_total: .4f} GB, t={time.time()-self._tstart: .1f}"
# Flattens `tensor` to 1-D and right-pads it with zeros (F.pad default) so its
# element count becomes a multiple of the `world_size` argument.
# NOTE(review): indentation is lost in this copy, so it cannot be told whether
# the flatten/pad steps also run on the Neuron path or sit in a missing
# `else:` branch. The Neuron assert checks `tensor.shape[0]` against
# `self.world_size` (the attribute, not the argument), which suggests the
# tensor is left unflattened there -- TODO confirm against upstream.
def _flatten_and_pad_to_world_size(self, tensor: torch.Tensor,
world_size: int) -> torch.Tensor:
"""Flatten and pad a tensor to a given world size (for reduce-scatter)."""
if self.on_neuron:
assert tensor.shape[0] % self.world_size == 0 # TODO: pad for world size
tensor = tensor.flatten()
if tensor.numel() % world_size != 0:
# Number of trailing elements needed to reach the next multiple.
pad_size = world_size - tensor.numel() % world_size
tensor = F.pad(tensor, [0, pad_size])
return tensor
def apply_to_tensors(
    fn: Callable, container: Union[torch.Tensor, Dict, List, Tuple,
                                   Set]) -> Any:
    """Recursively apply ``fn`` to every tensor inside ``container``.

    Args:
        fn: callable invoked on each tensor; its return value replaces the
            tensor in the rebuilt structure.
        container: a tensor, or an arbitrarily nested combination of dicts,
            lists, tuples and sets. Any other object is returned unchanged.

    Returns:
        A new structure of the same shape as ``container``. ``OrderedDict``
        subclasses and namedtuples keep their type; other dict/list/tuple/set
        inputs come back as the plain builtin type. A ``PackedSequence`` is
        returned as the same object.
    """

    def _apply(x: Union[torch.Tensor, Dict, List, Tuple, Set]) -> Any:
        if torch.is_tensor(x):
            return fn(x)
        if isinstance(x, PackedSequence):
            # PackedSequence is itself a namedtuple, so it must be handled
            # before the generic tuple branch. Its packed data is visited for
            # ``fn``'s side effects (e.g. hook registration), consistent with
            # ``collect_tensors``, but the object itself is left intact.
            fn(x.data)
            return x
        if isinstance(x, OrderedDict):
            od = x.__class__()
            for key, value in x.items():
                od[key] = _apply(value)
            return od
        if isinstance(x, dict):
            return {key: _apply(value) for key, value in x.items()}
        if isinstance(x, list):
            return [_apply(item) for item in x]
        if isinstance(x, tuple):
            items = [_apply(item) for item in x]
            if hasattr(x, "_fields"):
                # namedtuple: rebuild the same type so field access by name
                # keeps working on the result.
                return type(x)(*items)
            return tuple(items)
        if isinstance(x, set):
            return {_apply(item) for item in x}
        return x

    return _apply(container)
def collect_tensors(
    container: Union[torch.Tensor, Dict, List, Tuple,
                     Set]) -> List[torch.Tensor]:
    """Recursively collect all tensors in different kinds of container types.

    Args:
        container: a tensor, a ``PackedSequence``, or an arbitrarily nested
            combination of dicts, lists, tuples and sets. Other objects are
            ignored.

    Returns:
        The tensors found, in traversal order, with each tensor object listed
        once (deduplicated by ``id``). For a ``PackedSequence`` only its
        ``data`` tensor is collected.
    """

    def _collect(x, out, out_ids) -> None:
        if torch.is_tensor(x):
            if id(x) not in out_ids:
                out_ids.add(id(x))
                out.append(x)
        elif isinstance(x, PackedSequence):
            # Checked before the tuple branch because PackedSequence is a
            # namedtuple; only the packed data is of interest.
            _collect(x.data, out, out_ids)
        elif isinstance(x, dict):  # also covers OrderedDict
            for value in x.values():
                _collect(value, out, out_ids)
        elif isinstance(x, (list, tuple, set)):
            for value in x:
                _collect(value, out, out_ids)

    tensors = []
    _collect(container, tensors, set())
    return tensors
def _calc_grad_norm(parameters: List[torch.nn.Parameter],
p: float) -> torch.Tensor:
Calculate gradient norm of an iterable of parameters.
Total norm of the parameters (viewed as a single vector).
if len(parameters) == 0:
return torch.tensor(0.0)
if p == inf:
local_norm = max(par.grad.detach().abs().max() for par in parameters)
local_norm = torch.norm(
torch.stack([torch.norm(par.grad.detach(), p) for par in parameters]),
return local_norm
def _cast_floats_tensors(dtype: torch.dtype, *args: Any,
                         **kwargs: Any) -> Tuple[Any, Any]:
    """
    Cast floating point Tensors in *args or **kwargs to dtype if they are not.

    Args:
        dtype: target floating point dtype.
        *args: positional values, possibly nested containers of tensors.
        **kwargs: keyword values, possibly nested containers of tensors.

    Returns:
        A ``(args, kwargs)`` pair rebuilt by ``apply_to_tensors`` in which
        every floating point tensor has ``dtype``. Non-floating-point tensors
        and tensors that already have ``dtype`` are passed through as-is.
    """

    def fn(t):
        if t.dtype != dtype and torch.is_floating_point(t):
            t = t.to(dtype)
        return t

    return apply_to_tensors(fn, args), apply_to_tensors(fn, kwargs)
