Automatic mixed precision for Pytorch: supplementary information

Big-picture concerns:

  • Movement between autocasting-enabled and autocasting-disabled regions
  • Nesting FP32-enforced regions within autocasting regions
  • Extension kernels, which might not route through Pytorch dispatch at all
  • Avoid routing through an Amp dispatch layer wherever we can help it, because it incurs an extra lap through the dispatcher
  • Avoid the Amp dispatch layer handling all functions (we should only need a subset)
  • All operations that take multiple inputs and are not safe to run with different precisions among inputs (i.e., do not support built-in type promotion) must receive an explicit Amp backend function

All of the UXs below are possible with the current dispatch. 1 and 2 would be simpler to implement than 3 and 4.

Approach 1: Just a Context Manager

An amp.autocast context manager flips a global flag that controls whether or not ops route through an Amp dispatch layer. Tensors themselves are not given any special additional identity.

with amp.autocast():
    ops... # Safe to enter autocast-enabled region.
    with amp.autocast(enabled=False):
        ops... # Ops here may need to deal with a mixture of float and half tensors, and require manual casts to float.
               # Type promotion will smooth over some of these.
    ops... # Safe to reenter autocast-enabled region.
ops... # Ops here have to deal with a mixture of float and half tensors created under the context manager.
       # Errors will crop up one by one and require a manual float conversion in each case.  The errors will be clear
       # and easy to find on a per-op basis, though.  With type promotion, there may not even be that many.
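For concreteness, here is a minimal sketch of the Python surface such a context manager could expose. The thread-local flag is an assumption made for the sketch; in practice the flag and the dispatch hook would live in C++ next to the dispatcher.

import threading

_tls = threading.local()  # hypothetical module-level state, for the sketch only

def is_autocasting():
    return getattr(_tls, "autocast_enabled", False)

class autocast:
    def __init__(self, enabled=True):
        self.enabled = enabled

    def __enter__(self):
        # Save the surrounding state so nested regions, including
        # autocast(enabled=False) for FP32-enforced nesting, restore correctly.
        self.prev = is_autocasting()
        _tls.autocast_enabled = self.enabled

    def __exit__(self, exc_type, exc_value, traceback):
        _tls.autocast_enabled = self.prev
        return False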

Advantages:

  • Manual control
  • Ops that don't need Amp special treatment can "fall through" to the next step along the dispatch path (autograd history recording in VariableType*.cpp, most likely), saving 1 round trip through the dispatch machinery. Ed has not implemented the fallthrough yet, but he is pushing for the idea (pytorch/pytorch#28386).

Disadvantages:

  • Bleedover of tensors created with different types from autocasting-enabled regions to autocasting-disabled regions (a concrete sketch follows this list). People may have to insert manual casts. The places these manual casts must go will be easy to find, and minimized by kernels supporting type promotion. We don't know for certain how common/annoying this will be for typical networks. This could be regarded as a "documentation problem."
  • backward() should not be under the context manager, which is a gotcha people may easily run into.
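To make the bleedover concern concrete, a minimal sketch (amp here is the namespace proposed in this document; the per-op behavior is illustrative, not guaranteed):

import torch
# "amp" below refers to the module proposed in this document, not an existing API.

model = torch.nn.Linear(8, 8).cuda()
x = torch.randn(4, 8, device="cuda")

with amp.autocast():
    a = model(x)                         # under the proposal, a may come out as torch.half

b = torch.randn(4, 8, device="cuda")     # ordinary float tensor created outside the region
c = a + b                                # fine if elementwise add supports type promotion
d = torch.mm(a, b.t())                   # an op without promotion raises a dtype mismatch;
                                         # the fix is a manual cast: torch.mm(a.float(), b.t())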

Approach 2: Just Decorators (with context manager as implementation detail)

class MyModule(torch.nn.Module):
    @amp.float
    def my_float_func(args...):
        ops...

    @amp.half # Maybe this one should not exist at all.
    def my_half_func(args...):
        ops...

    @amp.autocast
    def forward(args...):
        ops...
       
# amp.autocast would look like:
def autocast(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        already_autocasting = amp.is_autocasting()
        if already_autocasting:
            return func(*args, **kwargs) # Go with the flow
        else:
            with amp._autocast(): # Enable autocasting for the duration of this call
                return cast_to_float(func(*args, **kwargs)) # Cast the output to float
    return wrapper

# amp.float would look like
def float(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        with amp._autocast(enabled=False): # Disable autocasting
            return func(*cast_to_float(args), **cast_to_float(kwargs)) # Cast tensor args/kwargs to float before running func
    return wrapper

Advantages:

  • Simplicity of use for common cases
  • Since we are wrapping things that have explicitly-known inputs and outputs, we can minimize bleedover of tensors created with different types from autocasting-enabled to autocasting-disabled regions. We can ensure that outputs are always cast to float when exiting an autocasting region.
  • The danger of running the backward pass with autocasting enabled is reduced.
  • Fallthrough in the backend is just as viable as for the raw context manager.
  • If the JIT script parser can handle decorator statements directly, as opposed to parsing the expanded form that may include the with statement, maybe we won't need with-statement support in the JIT to make this API work. Someone could write
    @torch.jit.script
    @amp.autocast
    def my_jit_autocast_function...
    

Disadvantages:

  • There is still danger of bleedover, if @amp.float functions use data that doesn't come in through the argument list, or @amp.autocast functions create tensors that are supplied to the outside world by some means other than the function outputs. This could be regarded as a "documentation problem."
  • The granularity at which regions can be autocasted must coincide with functions. If people require finer granularity, they must either use the implementation-detail context manager to enable/disable that region (with all the bleedover that implies) or write a new function to encapsulate the desired region, and decorate that function.

Approach 3: AmpTensors

AmpTensors become a unique datatype, as in the msnpu test (https://github.com/pytorch/pytorch/blob/master/test/test_cpp_extensions.py#L739). We can decide whether they are exposed as a separate device or Ampness is a separate tensor axis.

UX: model.amp().

All operations where at least 1 tensor is an AmpTensor will route through the Amp dispatch layer, and be subject to autocasting. These operations will output Tensors that are also AmpTensors. AmpTensor identity would be independent from floatness or halfness. Autocasting and Ampness would occur and propagate to wherever AmpTensors were used, just like grad history recording and requires_gradness.

Autocasting would not be triggered for extension calls. Extension ops would need to cast their inputs to float manually, whether in an autocasting-enabled or -disabled region. We should probably supply a decorator to help (a sketch follows). We would also need a context manager or decorator for nested FP32-enforced regions.
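A minimal sketch of what such a helper could look like, assuming only that the extension kernels want FP32 inputs. The name float_function and the stand-in my_extension.fused_op are hypothetical.

import functools
import torch

def float_function(func):
    # Hypothetical helper: cast any floating-point tensor arguments to FP32
    # before calling an extension op that only provides FP32 kernels.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        def to_float(x):
            if isinstance(x, torch.Tensor) and x.is_floating_point():
                return x.float()
            return x
        args = tuple(to_float(a) for a in args)
        kwargs = {k: to_float(v) for k, v in kwargs.items()}
        return func(*args, **kwargs)
    return wrapper

# Usage (my_extension.fused_op stands in for a custom extension kernel):
# fused_op_fp32 = float_function(my_extension.fused_op)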

Fallthrough would still work.

Approach 4: Using AmpTensors to augment Context Manager/Decorator approach

AmpTensors could facilitate "self-cleaning" context managers/decorators.

The context managers control a global flag as usual. Under context managers, whitelist/blacklist ops always dispatch through an Amp layer, AND return Tensors that have been given AmpTensor identity (or maybe only HalfTensors would need to be given AmpTensor identity, because they're the only ones that would need to be cleaned up?). After context manager exit, the global flag is False, but any operation with at least one AmpTensor among the inputs will still route through the Amp dispatch layer. The Amp function, seeing that the global autocasting flag is False, will realize its autocasting shenanigans are no longer welcome, cast any AmpTensor arguments to float, run the op, and return ordinary float Tensors that do not have Amp identity.

Autocasting and "self-cleaning" would not be triggered for custom/extension ops. Extension ops would need to cast their inputs to float manually, whether in an an autocasting-enabled or disabled region. We should probably supply a decorator to help.

Typical closure invocation (without gradient scaling) looks like

for input, target in dataset:
    def closure():
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()
        return loss
    loss = optimizer.step(closure)

The challenge is that step(closure) internally runs the closure to compute gradients, then immediately applies them to the parameters. We need some way to check for infs/nans and avoid applying inf/nan gradients to the parameters. However, we can't touch the source code of optimizer.step(), because that is the Custom Optimizer Pitfall, which means we can't interpose inf/nan-checking code between the closure invocation and the param updates in the body of step().

The API can ask the user to change the body of the closure, and to change the optimizer.step(closure) to optimizer.unscale_and_step(closure) or optimizer.step_after_unscale(closure). We also have control over the body of unscale_and_step and step_after_unscale, although those functions must eventually call the underlying, unmodified self.step(closure).

One solution would be for unscale_and_step to wrap the closure in another function that checks for infs/nans and replays the closure with a reduced loss scale if infs/nans are found. The wrapped closure would be guaranteed to produce non-inf/nan gradients, so it could safely be passed along to self.step. On the user side, that might look like

for input, target in dataset:
    def closure(S):
        # ^ the user augments the closure definition to accept the scale, so it can be modified by the wrapper code.
        # Python references bind to the name in the nearest scope that surrounds their textual definition,
        # rather than the scope(s) that surround(s) their actual point-of-use:
        # https://docs.python.org/3/tutorial/classes.html#python-scopes-and-namespaces
        # Since "closure" is def-ed in the training loop, references to "S" in the closure will bind
        # to the training-loop value of S.  If the wrapper internally redefines S then runs the closure,
        # the wrapper's redefinition will not be picked up by S within the closure.
        # def wrapper(closure):
        #     S = new_scale
        #     closure() # S within the closure will still refer to whatever value is in scope in the training loop, NOT new_scale.
        # To give the wrapper the ability to modify S and rerun the closure, we request that the closure take S as an argument.
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        scale_outputs(loss, S).backward()
        return loss
    loss, found_inf, S = optimizer.unscale_and_step(closure, grads_scaled_by=S)
    # ^found_inf will always contain False, because the closure is internally replayed until no infs/nans are found.

and unscale_and_step itself might look like

def unscale_and_step(self, closure=None, grads_scaled_by=None):
    if closure is not None:
        found_inf = None
        S = grads_scaled_by
        def safe_closure_wrapper():
            nonlocal found_inf, S # Tell safe_closure_wrapper to reference the values defined above.  Could also stash onto self.
            loss = closure(S)
            found_inf, S = self.unscale(grads_scaled_by=S)
            while found_inf.item():
                loss = closure(S)
                found_inf, S = self.unscale(grads_scaled_by=S)
            return loss
        loss = self.step(safe_closure_wrapper)
        return loss, found_inf, S
    else:
        ...

If someone wants to unscale within the closure, the situation becomes uglier because found_inf and S will be created within the closure itself, and must be communicated back to the wrapper somehow. The simplest approach will be to require that if the user unscales within the closure, the closure should return the found_inf and scale returned by unscale. On the user side, that would look like

for input, target in dataset:
    def closure(S):
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        scale_outputs(loss, S).backward()
        found_inf, S = optimizer.unscale(grads_scaled_by=S)
        # manipulate gradients as desired
        return loss, found_inf, S
    loss, found_inf, S = optimizer.step_after_unscale(closure, current_grad_scale=S)

and step_after_unscale might look like

def step_after_unscale(self, closure=None, found_inf=None, current_grad_scale=None):
    if closure is not None:
        found_inf = None
        S = current_grad_scale
        def safe_closure_wrapper():
            nonlocal found_inf, S # could also stash onto self
            loss, found_inf, S = closure(S)
            while found_inf.item():
                loss, found_inf, S = closure(S)
            return loss # The wrapper only returns loss, because that's what self.step() internally expects
        loss = self.step(safe_closure_wrapper)
        return loss, found_inf, S
    else:
        ...

I'm not a huge fan of this because it requires the addition of the current_grad_scale argument to step_after_unscale, but it works.

Alternatives to amp.scale_outputs

The scale factor may be applied by directly multiplying the loss/outputs, directly multiplying any manually-fed gradients, or by a keyword argument to the backward call (see examples below).

Example Backward Invocations

Scaling gradients for ordinary backward:

(loss*scale).backward() # or
loss.backward(scale_grads_by=scale)

Scaling gradients with torch.autograd.backward:

torch.autograd.backward((output0*scale, output1*scale)) # or
torch.autograd.backward((output0, output1), grad_tensors=(grad0*scale, grad1*scale)) # or
torch.autograd.backward((output0, output1), grad_tensors=(grad0, grad1), scale_grads_by=scale)

Scaling gradients with torch.autograd.grad:

torch.autograd.grad((output0*scale, output1*scale), model.parameters()) # or
torch.autograd.grad((output0, output1), model.parameters(), grad_outputs=(grad0*scale, grad1*scale)) # or
torch.autograd.grad((output0, output1), model.parameters(), scale_grads_by=scale)

Note that for the backward/grad calls, there is no proposed change to the existing API aside from the additional keyword argument scale_grads_by=scale, which would be for convenience/syntactic sugar. If scale_grads_by=scale is supplied, a call to backward or grad would immediately (on the Python side) scale the gradient(s) or output(s) by scale before handing them to Variable._execution_engine.run_backward.

scale_grads_by may also be None, in which case it will be ignored. This possibility allows users to easily switch gradient scaling on and off by globally setting scale to None (see unscale_and_step Implementation in the next section, and Switching automatic mixed precision on and off under End to End Examples).

As shown, omitting scale_grads_by and manually multiplying the outputs or gradients by scale will also be a valid approach. However, in this case it won't be as convenient to globally switch gradient scaling on and off.
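To make the Python-side handling concrete, here is a sketch written as a standalone helper that defers to the existing torch.autograd.backward. The function name backward_with_scaling is hypothetical; the proposal would instead add the kwarg to backward/grad themselves.

import torch

def backward_with_scaling(tensors, grad_tensors=None, retain_graph=None,
                          create_graph=False, scale_grads_by=None):
    # Hypothetical helper illustrating the proposed scale_grads_by kwarg: scale
    # the seed gradients on the Python side, then defer to the existing
    # torch.autograd.backward unchanged.
    tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tuple(tensors)
    if isinstance(grad_tensors, torch.Tensor):
        grad_tensors = (grad_tensors,)
    if scale_grads_by is not None:
        if grad_tensors is None:
            # backward() implicitly seeds scalar outputs with 1.0, so the scaled
            # seed is simply the scale itself.
            grad_tensors = tuple(torch.ones_like(t) for t in tensors)
        grad_tensors = tuple(g * scale_grads_by for g in grad_tensors)
    # scale_grads_by=None falls through untouched, which is what lets users
    # switch gradient scaling off globally by setting the scale to None.
    torch.autograd.backward(tensors, grad_tensors=grad_tensors,
                            retain_graph=retain_graph, create_graph=create_graph)

Calling backward_with_scaling(loss, scale_grads_by=scale) would behave like the proposed loss.backward(scale_grads_by=scale).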

I see two options for making wrapper methods available to the optimizer instance:

  • Passing each optimizer through some initial script-side API call, e.g. torch.amp.patch_optimizer(optimizer), that patches the methods onto the instance using types.MethodType.
  • Extending torch.optim.Optimizer with the methods, which will be inherited by Torch optimizers and user-defined optimizers.

Patching and inheritance both avoid the Custom Optimizer Pitfall. They transparently enable mixed precision support for existing custom optimizers without the authors needing to alter their step() methods to handle a new argument.
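To illustrate the patching mechanism, a minimal sketch (the _default_* names below are hypothetical stand-ins for the wrapper implementations sketched earlier in this document):

import types

def _default_unscale(self, grads_scaled_by=None):
    ...  # stand-in for the default unscale implementation

def _default_unscale_and_step(self, closure=None, grads_scaled_by=None):
    ...  # stand-in for the unscale_and_step wrapper sketched above

def _default_step_after_unscale(self, closure=None, found_inf=None, current_grad_scale=None):
    ...  # stand-in for the step_after_unscale wrapper sketched above

def patch_optimizer(opt):
    # Attach the default wrappers as bound methods.  The hasattr check leaves
    # any method the optimizer author already defined undisturbed, as described
    # under Patching below.
    for name, impl in (("unscale", _default_unscale),
                       ("unscale_and_step", _default_unscale_and_step),
                       ("step_after_unscale", _default_step_after_unscale)):
        if not hasattr(opt, name):
            setattr(opt, name, types.MethodType(impl, opt))
    return opt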

Note: For all implementations below, step_after_unscale() could instead be step_after_unscale(found_inf) so that unscale does not need to set the inf state internally.

Patching

If we choose to add the methods via patching, the user must pass each optimizer through a patching call (e.g., torch.amp.patch_optimizer(opt)) before training. Here are the possible UXs that I see:

  1. torch.amp.patch_optimizer adds unscale_and_step, unscale, and step_after_unscale methods:

    torch.amp.patch_optimizer(opt)
    ...
    # typical use
    _, found_inf, recommended_scale  = opt.unscale_and_step(grads_scaled_by=scale)
    # use with separate unscaling and stepping
    found_inf, recommended_scale = opt.unscale(grads_scaled_by=scale)
    opt.step_after_unscale() # consumes and handles an inf state set by unscale

    Existing optimizers won't need to define unscale_and_step or step_after_unscale methods. However, if users want more control, they may define those methods on their custom optimizers. torch.amp.patch_optimizer(opt) will include a hasattr check for unscale_and_step and step_after_unscale. If opt has those methods already defined, torch.amp.patch_optimizer will leave them undisturbed.

  2. torch.amp.patch_optimizer adds unscale_and_step and unscale methods, but not step_after_unscale:

    torch.amp.patch_optimizer(opt)
    ...
    # typical use
    _, found_inf, recommended_scale  = opt.unscale_and_step(grads_scaled_by=scale)
    # use with separate unscaling and stepping
    found_inf, recommended_scale = opt.unscale(grads_scaled_by=scale)
    if not found_inf.item(): # Manual inf/nan check
        opt.step() # Ordinary step()

    Again, custom optimizers may choose to define their own unscale_and_step method, in which case torch.amp.patch_optimizer won't touch it.

Between 1. and 2. I like 1. better, because it gives custom optimizers the option to take full control of step_after_unscale as well. For example, we'd like to implement optimizers that carry out sync-free dynamic loss scaling, where found_inf's data_ptr is passed directly to an optimizer kernel that internally decides to update the param values or not. If a user script looks like this

if not found_inf.item(): # Manual inf/nan check
    opt.step()

the item() always incurs a host-device sync, and there's nothing a custom optimizer writer can change in their step() method to avoid the sync.

Sync-free dynamic loss scaling is just an example. In general, I like an API that provides custom optimizers with both unscale_and_step and step_after_unscale control points. I don't want to paint custom optimizers into any corners.

  3. torch.amp.patch_optimizer patches step itself with a thin wrapper that handles a grads_scaled_by kwarg, and with logic to check whether a previous unscale call found an inf:

    torch.amp.patch_optimizer(opt)
    ...
    # typical use
    _, found_inf, recommended_scale  = opt.step(grads_scaled_by=scale)
    # use with separate unscaling and stepping
    found_inf, recommended_scale = opt.unscale(grads_scaled_by=scale)
    opt.step() # step() has been patched to check if unscale has set an inf state
  4. torch.amp.patch_optimizer patches step itself with a thin wrapper that handles a grads_scaled_by kwarg, but defers inf/nan checking to the user in case of separate unscaling:

    torch.amp.patch_optimizer(opt)
    ...
    # typical use
    _, found_inf, recommended_scale  = opt.step(grads_scaled_by=scale)
    # use with separate unscaling and stepping
    found_inf, recommended_scale = opt.unscale(grads_scaled_by=scale)
    if not found_inf.item():
        opt.step()

I'm mostly mentioning 3. and 4. for completeness. I don't like them.

Since opt always has a step method defined, torch.amp.patch_optimizer has no way of knowing to bail out and leave opt.step untouched when step contains its own custom logic to handle unscaling and inf checking. To let such a custom optimizer's logic operate, the user would also have to comment out torch.amp.patch_optimizer in their script. This is confusing relative to APIs 1. and 2., which require scripts to call unscale_and_step/step_after_unscale: custom optimizers that define those methods can be swapped in and will work as their authors intended, without the user script needing to comment out torch.amp.patch_optimizer.

Inheritance

Instead of patching the thin-wrapper methods onto the opt instance, we may add default implementations of these methods to the base class torch.optim.Optimizer, so they'll be inherited and user scripts won't need to call torch.amp.patch_optimizer(opt). Here are the possible UXs that I see:

  5. torch.optim.Optimizer provides unscale_and_step, unscale, and step_after_unscale methods:

    # torch.amp.patch_optimizer(opt) not needed
    # typical use
    _, found_inf, recommended_scale  = opt.unscale_and_step(grads_scaled_by=scale)
    # use with separate unscaling and stepping
    found_inf, recommended_scale = opt.unscale(grads_scaled_by=scale)
    opt.step_after_unscale() # consumes and handles an inf state set by unscale
  6. torch.optim.Optimizer provides unscale_and_step and unscale, but not step_after_unscale:

    # torch.amp.patch_optimizer(opt) not needed
    # typical use
    _, found_inf, recommended_scale  = opt.unscale_and_step(grads_scaled_by=scale)
    # use with separate unscaling and stepping
    found_inf, recommended_scale = opt.unscale(grads_scaled_by=scale)
    if not found_inf.item(): # Manual inf/nan check
        opt.step() # Ordinary step()

As with my preference for 1. over 2., I prefer 5. over 6. because it gives custom optimizers the chance to take full control of step_after_unscale.

Also, again, existing optimizers won't need to define unscale_and_step or step_after_unscale methods. Their existing step() methods will be sufficient. However, if users want to take more control, overriding unscale_and_step or step_after_unscale in derived optimizers will be acceptable.
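To make the inherited default concrete, here is a minimal sketch of what the base-class unscale might do, written as a free function for clarity. The in-place gradient loop and the halve-on-overflow policy are placeholder choices, not a committed design.

import torch

def default_unscale(optimizer, grads_scaled_by=None):
    # Sketch of the default unscale behavior torch.optim.Optimizer could inherit:
    # divide every gradient by the scale in place, record whether any inf/nan was
    # found, and return (found_inf, recommended_scale).
    found_inf = torch.zeros(())
    if grads_scaled_by is None:
        # Scaling is switched off; nothing to do.
        return found_inf, None
    inv_scale = 1.0 / grads_scaled_by
    for group in optimizer.param_groups:
        for p in group["params"]:
            if p.grad is None:
                continue
            p.grad.mul_(inv_scale)
            if not torch.isfinite(p.grad).all():
                found_inf.fill_(1.0)
    # Placeholder policy: halve the scale on overflow, otherwise leave it alone.
    # A real implementation would also stash found_inf so step_after_unscale can
    # consume it.
    recommended_scale = grads_scaled_by * (0.5 if found_inf.item() else 1.0)
    return found_inf, recommended_scale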

TL;DR

Overall, my preference is for 5., which uses inheritance.

# torch.amp.patch_optimizer(opt) not needed
# typical use
_, found_inf, recommended_scale  = opt.unscale_and_step(grads_scaled_by=scale)
# use with separate unscaling and stepping
found_inf, recommended_scale = opt.unscale(grads_scaled_by=scale)
opt.step_after_unscale() # consumes and handles an inf state set by unscale

I believe 5. offers a convenient UX in the common case, as well as the control points some crazy/speed-of-light-seeking custom optimizer author would want, with the fewest training-script-side lines changed.

However, 1., which uses patching, offers custom optimizer writers the same control points without extending torch.optim.Optimizer, at the cost of an additional line of user code.

torch.amp.patch_optimizer(opt)
...
# typical use
_, found_inf, recommended_scale  = opt.unscale_and_step(grads_scaled_by=scale)
# use with separate unscaling and stepping
found_inf, recommended_scale = opt.unscale(grads_scaled_by=scale)
opt.step_after_unscale() # consumes and handles an inf state set by unscale

Having the torch.amp.patch_optimizer(opt) control point in user scripts may also prove useful in the future for unforeseen reasons.

I like both 1. and 5.

Why patching or inheritance, and not wrapping the optimizer instance in a wrapper class?

In principle, instead of using patching or inheritance, we could wrap the optimizer instance in a wrapper class that provides the step-wrapping unscale_and_step, etc. methods we need. However, wrapping the optimizer instance hides top-level attributes. DistributedDataParallel wraps a model rather than patching onto it, and people are forced to write irritating control flow like

if distributed:
    model.module.attr
else:
    model.attr

We could implement a custom __getattr__ on the wrapper class that forwards attribute accesses to the wrapped optimizer instance (see the sketch below), but that seems more hacky than just adding methods to the instance directly (via either inheritance or patching).
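For completeness, the forwarding hack would look roughly like this (AmpOptimizerWrapper is a hypothetical name, shown only to illustrate why the wrapper approach feels hacky):

class AmpOptimizerWrapper:
    # Hypothetical wrapper-class alternative; not a proposed API.
    def __init__(self, optimizer):
        self.optimizer = optimizer

    def unscale_and_step(self, closure=None, grads_scaled_by=None):
        ...  # step-wrapping logic as sketched earlier in this document

    def __getattr__(self, name):
        # Only called when normal lookup fails, so this forwards param_groups,
        # state, zero_grad, etc. to the wrapped optimizer and user scripts can
        # keep writing opt.attr instead of opt.optimizer.attr.
        return getattr(self.optimizer, name)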
