@sshleifer
Last active August 19, 2020 18:10
Breakpoint at /home/shleifer/fairseq/fairseq/tasks/fairseq_task.py(385)train_step() (line 385, inside train_step())

(First, download fairseq_wmt_enro.tgz from S3 with wget.)

During training, fairseq passes mBART dynamically sized batches (at most 128 tokens each, set by --max-tokens 128 in the command below) in a dict called sample with the following relevant keys:

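  • target (our labels): no BOS; ends with [2, tgt_lang_code], where 2 is the </s> (EOS) index.
  • net_input.src_tokens (our input_ids): ends with [2, src_lang_code], here [2, 250004] (250004 is en_XX).
  • net_input.prev_output_tokens (our decoder_input_ids): starts with tgt_lang_code (here 250020, ro_RO) and ends with 2. This is the "shift_tokens_right" version of target: the trailing language code is moved to the front and every other token shifts one position to the right.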
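The prev_output_tokens construction from the last bullet can be sketched in plain Python. This is a minimal sketch, not fairseq's code: rotate_last_to_front and shift_tokens_right_batch are hypothetical helper names, the word ids in the example are made up, and it assumes targets are right-padded with pad index 1 (fairseq's default Dictionary assigns <s>=0, <pad>=1, </s>=2, <unk>=3).

```python
from typing import List

# fairseq Dictionary defaults: <s>=0, <pad>=1, </s>=2, <unk>=3
PAD = 1
EOS = 2


def rotate_last_to_front(tokens: List[int]) -> List[int]:
    """Hypothetical helper: decoder inputs for ONE unpadded target.

    [..., EOS, lang_code] -> [lang_code, ..., EOS]
    """
    return tokens[-1:] + tokens[:-1]


def shift_tokens_right_batch(targets: List[List[int]], pad: int = PAD) -> List[List[int]]:
    """Rotate each unpadded target, then right-pad every row to the batch max length.

    Assumes right padding; the rotation happens before padding, so the
    language code (not a pad token) ends up at position 0.
    """
    max_len = max(len(t) for t in targets)
    return [rotate_last_to_front(t) + [pad] * (max_len - len(t)) for t in targets]


# Toy targets (made-up word ids), each ending in [EOS, 250020]
targets = [[11, 12, 13, EOS, 250020], [21, EOS, 250020]]
print(shift_tokens_right_batch(targets))
# [[250020, 11, 12, 13, 2], [250020, 21, 2, 1, 1]]
```

Rotating before padding is the important detail: if a padded row were rotated naively, a pad token rather than the language code would land at position 0.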

Here are the logs from my breakpoint:

```
ipdb> sample.keys()
dict_keys(['id', 'nsentences', 'ntokens', 'net_input', 'target'])
ipdb> sample['net_input'].keys()
dict_keys(['src_tokens', 'src_lengths', 'prev_output_tokens'])
ipdb> sample['target'][0]
tensor([  9345,    202,     10, 181684,     36,  21635,   8454,  48993,  45587,
            21,  57476,   1283,  98748,    451,    346,   8916,    202,     28,
             9,      7,    451,  11650, 128402,      5,      2, 250020],
       device='cuda:0')
ipdb> sample['net_input']['src_tokens'][0]
tensor([   581,   4738,  30666,    297,     10,  21635,   1363,     98,  28811,
           552,      9,  21473,   1363,     23,     70,     28,      9,  94005,
          8916,      5,      2, 250004], device='cuda:0')
ipdb> sample['net_input']['prev_output_tokens'][0]
tensor([250020,   9345,    202,     10, 181684,     36,  21635,   8454,  48993,
         45587,     21,  57476,   1283,  98748,    451,    346,   8916,    202,
            28,      9,      7,    451,  11650, 128402,      5,      2],
       device='cuda:0')
```
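As a sanity check on the bullets above, the logged rows can be copied into plain lists and asserted against. The values are taken verbatim from the ipdb output; the en_XX/ro_RO names for 250004/250020 are an assumption about the mBART cc25 dictionary and are not printed in the log.

```python
# Row 0 of each tensor from the ipdb session, copied verbatim as plain lists.
target = [9345, 202, 10, 181684, 36, 21635, 8454, 48993, 45587,
          21, 57476, 1283, 98748, 451, 346, 8916, 202, 28,
          9, 7, 451, 11650, 128402, 5, 2, 250020]
src_tokens = [581, 4738, 30666, 297, 10, 21635, 1363, 98, 28811,
              552, 9, 21473, 1363, 23, 70, 28, 9, 94005,
              8916, 5, 2, 250004]
prev_output_tokens = [250020, 9345, 202, 10, 181684, 36, 21635, 8454, 48993,
                      45587, 21, 57476, 1283, 98748, 451, 346, 8916, 202,
                      28, 9, 7, 451, 11650, 128402, 5, 2]

BOS, EOS = 0, 2                 # fairseq Dictionary defaults
EN_XX, RO_RO = 250004, 250020   # assumed mBART language-code ids

assert src_tokens[-2:] == [EOS, EN_XX]          # source ends with </s>, en_XX
assert target[-2:] == [EOS, RO_RO]              # target ends with </s>, ro_RO
assert target[0] != BOS                         # target has no BOS
assert prev_output_tokens[0] == RO_RO           # decoder input starts with ro_RO
assert prev_output_tokens[-1] == EOS            # ...and ends with </s>
assert prev_output_tokens == target[-1:] + target[:-1]  # last token rotated to front
print("all checks passed")
```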
Training command. It expects $SRC, $TGT, $PRETRAIN (the pretrained mBART checkpoint), $langs (the comma-separated mBART language list) and $DEST (the output directory) to be set in the shell:

```bash
fairseq-train fairseq_wmt_enro --encoder-normalize-before --decoder-normalize-before --arch mbart_large \
  --task translation_from_pretrained_bart --source-lang $SRC --target-lang $TGT \
  --criterion label_smoothed_cross_entropy --label-smoothing 0.2 \
  --dataset-impl mmap --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \
  --lr-scheduler polynomial_decay --lr 3e-4 --min-lr -1 \
  --warmup-updates 2500 --total-num-update 300000 --dropout 0.2 --attention-dropout 0.1 \
  --weight-decay 0.0 --max-tokens 128 --update-freq 2 --save-interval 1 --save-interval-updates 5000 \
  --keep-interval-updates 3 --no-epoch-checkpoints --seed 222 \
  --log-interval 2 --reset-optimizer --reset-meters --reset-dataloader --reset-lr-scheduler \
  --restore-file $PRETRAIN --langs $langs --layernorm-embedding \
  --ddp-backend no_c10d --save-dir $DEST --memory-efficient-fp16
```
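The dynamically sized batches come from --max-tokens 128. Below is a simplified, hypothetical sketch of token-budget batching (batch_by_max_tokens is an illustrative name, not fairseq's batch_by_size): sentences are visited shortest-first, and a batch is closed once (longest sentence in the batch) × (batch size) would exceed the budget, i.e. once the padded batch would hold more than max_tokens tokens. Each length is assumed to stand for the longer side of a source/target pair.

```python
from typing import List


def batch_by_max_tokens(lengths: List[int], max_tokens: int) -> List[List[int]]:
    """Simplified token-budget batching (illustrative, not fairseq's code).

    Returns lists of indices into `lengths`; every batch satisfies
    max(length in batch) * len(batch) <= max_tokens.
    """
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])  # shortest first
    batches: List[List[int]] = []
    batch: List[int] = []
    batch_max = 0
    for idx in order:
        n = lengths[idx]
        if n > max_tokens:
            raise ValueError(f"sentence {idx} has {n} tokens > max_tokens={max_tokens}")
        new_max = max(batch_max, n)
        # Adding idx would pad every row to new_max; close the batch if over budget.
        if batch and new_max * (len(batch) + 1) > max_tokens:
            batches.append(batch)
            batch, new_max = [], n
        batch.append(idx)
        batch_max = new_max
    if batch:
        batches.append(batch)
    return batches


lengths = [22, 26, 40, 10, 64, 30]  # made-up per-pair token counts
print(batch_by_max_tokens(lengths, max_tokens=128))
# [[3, 0, 1, 5], [2, 4]]
```

Under this sketch's assumption, row 0 of the log (22 source and 26 target tokens) counts as 26 tokens, so at most 4 such pairs fit in one batch (4 × 26 = 104 ≤ 128 < 5 × 26 = 130). With --update-freq 2, gradients from two such batches are accumulated before each optimizer step.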