(First, `wget` `fairseq_wmt_enro.tgz` from S3.)
During training, fairseq passes mBART dynamically sized batches (up to 128 tokens) in a dict called `sample`, with the following relevant keys:
- `target` (our `labels`): no bos, ends with `[2, tgt_lang_code]`
- `net_input.src_tokens` (our `input_ids`): ends with `[2, 250004]`
- `net_input.prev_output_tokens` (our `decoder_input_ids`): starts with `250020`, ends with `2`. This is the "shift_tokens_right" version of `target` (see the sketch after this list).
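The shift is just a rotate-right of `target`: the final language code wraps around to position 0. A minimal sketch of that operation (mine, not fairseq's code; assumes right-padding with `pad_token_id`):

```python
import torch

def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    """Wrap each row's last non-pad token (the language code) to the front
    and shift everything else one position to the right."""
    prev_output_tokens = input_ids.clone()
    # index of the last non-pad token in each row
    index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
    prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze(-1)
    prev_output_tokens[:, 1:] = input_ids[:, :-1]
    return prev_output_tokens
```

Applied to the `target` row in the logs below, this reproduces the logged `prev_output_tokens` row exactly.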
Here are the logs from my breakpoint:
```
ipdb> sample.keys()
dict_keys(['id', 'nsentences', 'ntokens', 'net_input', 'target'])
ipdb> sample['net_input'].keys()
dict_keys(['src_tokens', 'src_lengths', 'prev_output_tokens'])
ipdb> sample['target'][0]
tensor([ 9345, 202, 10, 181684, 36, 21635, 8454, 48993, 45587,
         21, 57476, 1283, 98748, 451, 346, 8916, 202, 28,
         9, 7, 451, 11650, 128402, 5, 2, 250020],
        device='cuda:0')
ipdb> sample['net_input']['src_tokens'][0]
tensor([ 581, 4738, 30666, 297, 10, 21635, 1363, 98, 28811,
         552, 9, 21473, 1363, 23, 70, 28, 9, 94005,
         8916, 5, 2, 250004], device='cuda:0')
ipdb> sample['net_input']['prev_output_tokens'][0]
tensor([250020, 9345, 202, 10, 181684, 36, 21635, 8454, 48993,
         45587, 21, 57476, 1283, 98748, 451, 346, 8916, 202,
         28, 9, 7, 451, 11650, 128402, 5, 2],
        device='cuda:0')
```
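For comparison, Hugging Face's `MBartTokenizer` produces the same source-side layout, appending `[eos, src_lang_code]` when it encodes. A quick check (assumes the `facebook/mbart-large-en-ro` checkpoint; the sentence is arbitrary):

```python
from transformers import MBartTokenizer

tok = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")
ids = tok("The council adopted a resolution on European co-regulation.").input_ids
print(ids[-2:])  # [2, 250004] -> </s> followed by the en_XX language code
```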
The full `fairseq-train` command:

```bash
fairseq-train fairseq_wmt_enro --encoder-normalize-before --decoder-normalize-before --arch mbart_large \
  --task translation_from_pretrained_bart --source-lang $SRC --target-lang $TGT --criterion label_smoothed_cross_entropy \
  --label-smoothing 0.2 --dataset-impl mmap --optimizer adam \
  --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \
  --lr-scheduler polynomial_decay --lr 3e-4 --min-lr -1 \
  --warmup-updates 2500 --total-num-update 300000 --dropout 0.2 --attention-dropout 0.1 \
  --weight-decay 0.0 --max-tokens 128 --update-freq 2 --save-interval 1 --save-interval-updates 5000 \
  --keep-interval-updates 3 --no-epoch-checkpoints --seed 222 \
  --log-interval 2 --reset-optimizer --reset-meters --reset-dataloader --reset-lr-scheduler \
  --restore-file $PRETRAIN --langs $langs --layernorm-embedding \
  --ddp-backend no_c10d --save-dir $DEST --memory-efficient-fp16
```
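To match losses against this run, the relevant piece is `--criterion label_smoothed_cross_entropy --label-smoothing 0.2`. A minimal sketch of that criterion (my paraphrase of the usual formulation, not fairseq's code; assumes pad id 1 and flattened `(N, vocab)` log-probs):

```python
import torch

def label_smoothed_nll_loss(lprobs: torch.Tensor, target: torch.Tensor,
                            epsilon: float, ignore_index: int = 1) -> torch.Tensor:
    """Label-smoothed cross entropy: blend the gold-token NLL with the
    summed NLL over the vocabulary, skipping pad positions."""
    target = target.unsqueeze(-1)                    # (N, 1)
    nll_loss = -lprobs.gather(dim=-1, index=target)  # gold-token NLL
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)  # summed over vocab
    pad_mask = target.eq(ignore_index)
    nll_loss = nll_loss.masked_fill(pad_mask, 0.0)
    smooth_loss = smooth_loss.masked_fill(pad_mask, 0.0)
    eps_i = epsilon / lprobs.size(-1)
    return (1.0 - epsilon) * nll_loss.sum() + eps_i * smooth_loss.sum()

# usage: logits (B, T, V) from the model, labels (B, T) padded with 1
# loss = label_smoothed_nll_loss(
#     logits.log_softmax(-1).view(-1, logits.size(-1)),
#     labels.view(-1), epsilon=0.2)
```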