Problem:

  • In WMT datasets, example length varies widely: some examples are one sentence, others are 10 sentences.
  • The max batch shape that fits on a V100 is roughly (4, 512).
  • You end up with lots of batches of shape (4, 12) or (4, small_int) that don't fully utilize the GPU.

Dynamic Batch Size: try to organize batches to contain roughly 4*512 = 2048 tokens, so one batch might be shaped (4, 512) and another (32, 64).
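A quick illustration of the arithmetic (the numbers here just restate the shapes above, nothing more):

```python
# Token budget per batch on a V100: 4 * 512 = 2048 tokens.
# With dynamic batching, differently shaped batches hit the same budget:
for bsz, seqlen in [(4, 512), (32, 64)]:
    print(bsz * seqlen)      # 2048 both times -> full utilization
# Compare a fixed-size batch of short sentences:
print(4 * 12 / 2048)         # ~0.02 -> only ~2% of the budget is used
```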

Details of Fairseq Solution:

Pass a

batch_sampler: List[List[int]] = [[id_0_batch0, id_1_batch0], [id_3764_batch_1], [id_3_batch_2, id_4_batch_2, id_5_batch_2]]

kwarg to DataLoader.
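A minimal sketch of what that looks like with a plain PyTorch DataLoader; `ToyTranslationDataset` and `collate_fn` are hypothetical stand-ins, not fairseq's actual classes:

```python
import torch
from torch.utils.data import DataLoader, Dataset

class ToyTranslationDataset(Dataset):
    """Hypothetical stand-in for a real WMT dataset of variable-length examples."""
    def __init__(self, lengths):
        self.examples = [torch.randint(1, 100, (n,)) for n in lengths]
    def __len__(self):
        return len(self.examples)
    def __getitem__(self, i):
        return self.examples[i]

def collate_fn(examples):
    # pad every example to the batch's own max length, not a global max
    return torch.nn.utils.rnn.pad_sequence(examples, batch_first=True, padding_value=0)

# each inner list is one batch of example indices; lists may differ in length
batch_sampler = [[0, 1], [2], [3, 4, 5]]

loader = DataLoader(
    ToyTranslationDataset([12, 14, 500, 35, 40, 8]),
    batch_sampler=batch_sampler,   # mutually exclusive with batch_size/shuffle/sampler
    collate_fn=collate_fn,
)

for batch in loader:
    print(batch.shape)             # e.g. (2, 14), (1, 500), (3, 40)
```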

Each entry in the list holds the example indices that compose one batch; the entries don't need to be the same length. The procedure is:

  1. (Optional) Sort examples by length (to save padding).
  2. Pack every entry in the list such that the included examples total at most max_tokens=4000 (this includes padding).

The batches are then presented in a different order each time at training (but in the above example, id_0 and id_1 would be in the same batch every time).
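A rough sketch of that procedure under simple assumptions (the `make_batches` helper, the greedy packing rule, and the toy `src_lengths` are mine, not fairseq's actual implementation):

```python
import random
from typing import List

def make_batches(src_lengths: List[int], max_tokens: int = 4000) -> List[List[int]]:
    """Pack example indices into batches of at most max_tokens, counting padding."""
    # 1. (optional) sort indices by length so similar-length examples share a batch
    order = sorted(range(len(src_lengths)), key=lambda i: src_lengths[i])

    batches, current, longest = [], [], 0
    for idx in order:
        longest = max(longest, src_lengths[idx])
        # padded size of the batch = num examples * longest example in it
        if current and (len(current) + 1) * longest > max_tokens:
            batches.append(current)
            current, longest = [], src_lengths[idx]
        current.append(idx)
    if current:
        batches.append(current)
    return batches

# toy lengths; in practice these come from the tokenized dataset
src_lengths = [12, 14, 500, 35, 40, 512, 8]
batches = make_batches(src_lengths)
# 2. each epoch, present the same batches in a new order
random.shuffle(batches)
```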

This made training 40% faster for mBART WMT fine-tuning without changing metrics.

Integration challenges
