@ben0it8
Last active July 17, 2019 09:00
load pretrained NAACL Transformer
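import torch
from pytorch_transformers import cached_path

# Download (and cache) the pre-trained weights and the training arguments used as the model config
state_dict = torch.load(cached_path("https://s3.amazonaws.com/models.huggingface.co/"
                                    "naacl-2019-tutorial/model_checkpoint.pth"), map_location='cpu')
config = torch.load(cached_path("https://s3.amazonaws.com/models.huggingface.co/"
                                "naacl-2019-tutorial/model_training_args.bin"))

# Init model: Transformer base + classifier head.
# TransformerWithClfHead and finetuning_config are assumed to be defined earlier.
model = TransformerWithClfHead(config=config, fine_tuning_config=finetuning_config).to(finetuning_config.device)
# strict=False tolerates key mismatches: parameters absent from the checkpoint
# (such as a newly added classifier head) keep their initial values.
model.load_state_dict(state_dict, strict=False)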
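
Because `strict=False` silently skips mismatched keys, it can help to check which parameters were actually covered by the checkpoint. The sketch below is a plain-Python illustration of that comparison, not part of the original snippet. The helper `compare_keys` and every parameter name in it are hypothetical and chosen only for the example.

```python
def compare_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) parameter names, each sorted.

    missing:    expected by the model but absent from the checkpoint
                (with strict=False these keep their initial values).
    unexpected: present in the checkpoint but not used by the model
                (with strict=False these are ignored).
    """
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


# Hypothetical parameter names, for illustration only.
model_keys = [
    "transformer.embed.weight",
    "classification_head.weight",
    "classification_head.bias",
]
checkpoint_keys = ["transformer.embed.weight", "lm_head.weight"]

missing, unexpected = compare_keys(model_keys, checkpoint_keys)
print("missing:", missing)
print("unexpected:", unexpected)
```

With a real model, the same comparison could be run on `model.state_dict().keys()` and `state_dict.keys()` before calling `load_state_dict`.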