
@apaszke
Last active September 7, 2017 02:17

Weight format proposal

Certain backends have additional requirements for the weights, and it would be good if we could transform the weights into this format once and reuse it across many iterations.

Currently we have two modules that could already benefit from these changes:

  • cuDNN RNN could format the weights into a single contiguous block of memory
  • BatchNorm could maintain fp32 running averages to stabilize fp16 training
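To show why fp32 running averages help, here is a small stdlib-only sketch (not PyTorch code) that emulates fp16 and fp32 storage by rounding through `struct`'s `'e'` (half) and `'f'` (single) formats. The momentum and batch mean are made-up values chosen to expose the rounding problem; `update_running_mean` is a hypothetical helper that applies the usual exponential moving average.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_fp32(x: float) -> float:
    """Round a Python float to the nearest IEEE single-precision value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def update_running_mean(running, batch_mean, momentum, cast):
    # Exponential moving average (as used for BatchNorm running statistics),
    # stored back at the precision chosen by `cast`.
    return cast(running + momentum * (batch_mean - running))


momentum = 0.01
batch_mean = 1.04  # every batch reports the same mean, for simplicity
r16 = to_fp16(1.0)
r32 = to_fp32(1.0)
for _ in range(100):
    r16 = update_running_mean(r16, to_fp16(batch_mean), momentum, to_fp16)
    r32 = update_running_mean(r32, batch_mean, momentum, to_fp32)

# Each fp16 step adds about 0.0004, less than half an fp16 ulp at 1.0 (~0.00049),
# so the update rounds away and the fp16 average never moves.
print(r16, r32)
```

The fp16 average stays stuck at 1.0 while the fp32 one converges toward the batch mean, which is the kind of stall that keeping fp32 running averages avoids.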

API:

Each backend-specific format is a ParamFormat subclass that implements two static methods, to_format and from_format:

import torch.nn as nn
from torch.nn.param_formats import ParamFormat, register_format  # module proposed here

# each class would have a registry of available formats
@register_format(nn._RNNBase, 'cudnn')
class cudnnRNNParamFormat(ParamFormat):
    @staticmethod
    def to_format(module):
        # check that module is on CUDA
        # compact the weights
        # NOTE: this can also replace `forward` so that it calls a different autograd Function
        ...

    @staticmethod
    def from_format(module):
        # restore module to original setting
        # wouldn't necessarily need to remap the weights,
        # but has to make params separate Variables
        ...

Usage:

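import torch.nn as nn

x = nn.LSTM(...).cuda()
x.set_weight_format('cudnn')           # changes params to contain a single flat weight vector
assert len(list(x.parameters())) == 1  # parameters() returns a generator, so materialize it
x.restore_weights()                    # reverts set_weight_format
assert len(list(x.parameters())) == 4  # a single-layer LSTM with biases has 4 params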
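The usage above can be modelled without a GPU. This is a toy sketch, assuming nothing about PyTorch internals: `ToyLSTM` is a hypothetical stand-in whose four weights are `array('f')` buffers, `set_weight_format` copies them into one contiguous buffer and exposes `memoryview` slices as per-weight views, and `restore_weights` splits the buffer back into separate arrays. The parameter sizes are made up.

```python
from array import array


class ToyLSTM:
    """Hypothetical stand-in for a single-layer LSTM with four flat float weights."""

    def __init__(self, sizes):
        # Fill each parameter with its index (0.0, 1.0, ...) so they are distinguishable.
        self._params = {name: array("f", [float(i)] * n)
                        for i, (name, n) in enumerate(sizes.items())}
        self._layout = None     # (name, offset, length) entries while in the flat format
        self.weight_views = {}  # per-weight views into the flat buffer

    def parameters(self):
        return iter(self._params.values())

    def set_weight_format(self, fmt):
        if fmt != "cudnn":
            raise ValueError(f"unknown weight format {fmt!r}")
        flat, layout = array("f"), []
        for name, p in self._params.items():
            layout.append((name, len(flat), len(p)))
            flat.extend(p)  # copy each weight into one contiguous buffer
        self._params, self._layout = {"flat_weight": flat}, layout
        # memoryview slices share memory with `flat`, like tensor views into a flat weight.
        self.weight_views = {name: memoryview(flat)[off:off + n]
                             for name, off, n in layout}

    def restore_weights(self):
        flat = self._params["flat_weight"]
        self.weight_views = {}
        # Slicing an array copies it, so each weight gets separate storage again.
        self._params = {name: flat[off:off + n] for name, off, n in self._layout}
        self._layout = None


x = ToyLSTM({"weight_ih_l0": 8, "weight_hh_l0": 8, "bias_ih_l0": 2, "bias_hh_l0": 2})
x.set_weight_format("cudnn")
assert len(list(x.parameters())) == 1
x.weight_views["bias_ih_l0"][0] = 5.0  # the write lands in the shared flat buffer
x.restore_weights()
assert len(list(x.parameters())) == 4
```

Writes made through a view while the module is locked survive `restore_weights`, because the views alias the flat buffer rather than holding copies.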

Important issues

  • Should we serialize weights in standard or locked format?
    • In standard format: cuDNN doesn't even define its RNN weight format, and the layout varies between versions, so there's no way to correctly convert locked weights back on another machine.
    • Serialization shouldn't remap storages, so we can't temporarily switch the module into standard format. Instead, we should create a shallow clone, convert it, and serialize that.
  • Should we restore weight format when deserializing modules?
    • I'm not 100% sure about this, but probably not. Some weight formats make sense only when used with backends on a specific device, and people can freely remap their storages when loading params.
  • Is it OK to add/remove/change parameters while the module is locked?
    • I'd say yes, because it makes life easier. Say we replace all RNN parameters with a single flat cuDNN vector; then we wouldn't need any additional logic in DataParallel to replicate the parameters in the correct format.
    • On the other hand, it's more annoying for users, but oh well.
  • Do type casts/device changes restore the weights to standard format?
    • Type casts: no. Device changes: yes.
  • Do we still use cuDNN for RNNs on GPU by default?
    • I suggest the following heuristic: if the sequence length is less than or equal to N, then no (cuDNN doesn't give us huge benefits for short sequences and causes memory usage to blow up). N would be somewhere in the range of 1-5, based on some benchmarks. It's annoying to force users to set the backend themselves, but I don't have a better idea. Hopefully the JIT will help here anyway.
  • Can setting format change the number of parameters?
    • Yes. We need to add fake biases for cuDNN RNNs even if they don't exist in standard format.
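Two of the answers above are concrete enough to sketch in plain Python (a toy model, not PyTorch code): serializing a locked module through a shallow clone so the live module's storage is never remapped, and the sequence-length heuristic for choosing cuDNN. `ToyRNN`, `params_for_saving`, `use_cudnn_rnn`, and the threshold of 3 are all hypothetical; the proposal only says N is somewhere in 1-5.

```python
import copy
from array import array


class ToyRNN:
    """Hypothetical module whose weights live in separate arrays or one flat buffer."""

    def __init__(self):
        self.params = {"weight": array("f", [1.0, 2.0]), "bias": array("f", [3.0])}
        self.layout = None  # (name, offset, length) entries while in the flat format

    def to_flat(self):
        flat, layout = array("f"), []
        for name, p in self.params.items():
            layout.append((name, len(flat), len(p)))
            flat.extend(p)
        self.params, self.layout = {"flat_weight": flat}, layout

    def to_standard(self):
        flat = self.params["flat_weight"]
        self.params = {name: flat[off:off + n] for name, off, n in self.layout}
        self.layout = None


def params_for_saving(module):
    """Return standard-format params without remapping the live module's storage."""
    if module.layout is None:
        return dict(module.params)
    clone = copy.copy(module)  # shallow: shares the flat buffer, has its own attributes
    clone.to_standard()        # rebinds clone.params only; module.params is untouched
    return clone.params


def use_cudnn_rnn(seq_len, on_gpu, short_seq_threshold):
    """Proposed heuristic: on GPU, skip cuDNN when seq_len <= N."""
    return on_gpu and seq_len > short_seq_threshold


m = ToyRNN()
m.to_flat()
saved = params_for_saving(m)
assert list(m.params) == ["flat_weight"]  # the live module is still in the flat format
# 3 is an arbitrary pick from the proposal's 1-5 range, not a benchmarked value.
assert use_cudnn_rnn(50, on_gpu=True, short_seq_threshold=3)
assert not use_cudnn_rnn(2, on_gpu=True, short_seq_threshold=3)
```

Because `copy.copy` gives the clone its own attribute dict, converting the clone rebinds only its `params`, which is the "shallow clone, convert it, and serialize that" approach rather than temporarily switching the live module.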