Skip to content

Instantly share code, notes, and snippets.

from fairseq.models.transformer import TransformerModel
# Build a translation interface from a local checkpoint directory.  The
# checkpoint file, the binarized data directory and the subword-nmt BPE
# codes are all given explicitly.
zh2en = TransformerModel.from_pretrained(
    '/path/to/checkpoints',
    checkpoint_file='checkpoint_best.pt',
    data_name_or_path='data-bin/wmt17_zh_en_full',
    bpe='subword_nmt',
    bpe_codes='data-bin/wmt17_zh_en_full/zh.code',
)

# Translate one sentence.  The return value is discarded here; assign or
# print it to inspect the result.
zh2en.translate('你好 世界')
import torch
# Enumerate the entry points exposed by the pytorch/fairseq hub repository.
torch.hub.list('pytorch/fairseq')  # [..., 'transformer.wmt16.en-de', ... ]

# Load a transformer trained on WMT'16 En-De, using Moses tokenization and
# subword-nmt BPE.
en2de = torch.hub.load(
    'pytorch/fairseq',
    'transformer.wmt16.en-de',
    tokenizer='moses',
    bpe='subword_nmt',
)
en2de.eval()  # disable dropout

# The underlying model is available under the *models* attribute
@register_task('classification')
class ClassificationTask(FairseqTask):
    # The task implementation is elided in this snippet; the Ellipsis
    # expression below is only a placeholder body.
    ...
# Statements of a single training step.  The names `model`, `criterion`,
# `optimizer`, `sample`, `update_num` and `ignore_grad` must be provided by
# the enclosing scope, which is not part of this snippet.
model.train()
model.set_num_updates(update_num)

# Forward pass: the criterion returns the loss plus bookkeeping values.
loss, sample_size, logging_output = criterion(model, sample)
if ignore_grad:
    # Zero the loss in place so the backward call below still runs, but on
    # a loss of zero.
    loss *= 0

# Backward pass is delegated to the optimizer wrapper.
optimizer.backward(loss)
# High-level training flow.  `args`, `num_epochs`, `optimizer`,
# `lr_scheduler` and `average_and_clip_gradients` are not defined in this
# snippet and must come from the surrounding program.

# setup the task (e.g., load dictionaries)
task = fairseq.tasks.setup_task(args)

# build model and criterion
model = task.build_model(args)
criterion = task.build_criterion(args)

# load datasets
task.load_dataset('train')
task.load_dataset('valid')

for epoch in range(num_epochs):
    # A fresh batch iterator over the training split is built every epoch.
    itr = task.get_batch_iterator(task.dataset('train'))
    for num_updates, batch in enumerate(itr):
        task.train_step(batch, model, criterion, optimizer)
        average_and_clip_gradients()
        optimizer.step()
        # Per-update schedule adjustment.  `num_updates` restarts at 0 each
        # epoch because it comes from `enumerate(itr)`.
        lr_scheduler.step_update(num_updates)
    # Per-epoch schedule adjustment, once all batches have been processed.
    lr_scheduler.step(epoch)
# Dictionary selection when source and target share one vocabulary.
# `args` and `task` come from the enclosing scope, which is not shown here.
# NOTE(review): the srcdict/tgtdict chain is assumed to belong inside the
# joined-dictionary branch -- confirm against the full function.
if args.joined_dictionary:
    # A joined dictionary is a single file, so at most one of the two
    # dictionary options may be supplied.
    assert not args.srcdict or not args.tgtdict, \
        "cannot use both --srcdict and --tgtdict with --joined-dictionary"

    if args.srcdict:
        src_dict = task.load_dictionary(args.srcdict)
    elif args.tgtdict:
        # Whichever dictionary was supplied is loaded into `src_dict`.
        src_dict = task.load_dictionary(args.tgtdict)
    else:
        # No dictionary supplied: --trainpref is then required.
        assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
def cli_main():
    """Parse the preprocessing command-line arguments and pass them to ``main``."""
    parser = options.get_preprocessing_parser()
    args = parser.parse_args()
    main(args)
# Run the command-line entry point only when executed as a script, not when
# imported as a module.
if __name__ == "__main__":
    cli_main()