
@RamonYeung
Created July 12, 2018 07:46
Show Gist options
  • Star 5 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save RamonYeung/988945c805938636fc85c5385bd3d1b4 to your computer and use it in GitHub Desktop.
Save RamonYeung/988945c805938636fc85c5385bd3d1b4 to your computer and use it in GitHub Desktop.
PyTorch Loading Pre-trained Models
# 1. Directly Load a Pre-trained Model
# https://github.com/pytorch/vision/tree/master/torchvision/models
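import torch.nn as nn
import torchvision.models as models

# Downloads the ImageNet weights on first use.
# (torchvision >= 0.13 prefers weights=models.ResNet50_Weights.DEFAULT
# and warns that pretrained=True is deprecated.)
resnet50 = models.resnet50(pretrained=True)
# or build the same architecture with randomly initialized weights
model = models.resnet50(pretrained=False)
# To adapt the network to a new task (here 2 classes), replace the last fc layer.
# ResNet-50's fc layer takes 2048 input features.
model.fc = nn.Linear(model.fc.in_features, 2)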
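To try the head replacement without downloading ResNet-50 weights, the sketch below uses a tiny hypothetical `TinyNet` that, like torchvision's ResNet, exposes its classifier as an attribute named `fc`. Reading `in_features` from the old layer avoids hard-coding 2048. Freezing everything except the new head is an optional, common choice when fine-tuning on a small dataset; it is an addition for illustration, not part of the original snippet.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for a pretrained backbone with a final `fc` layer."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        self.fc = nn.Linear(16, num_classes)

    def forward(self, x):
        return self.fc(self.features(x))


def replace_head(model, num_classes):
    """Swap `model.fc` for a freshly initialized layer with `num_classes` outputs."""
    in_features = model.fc.in_features  # reuse the existing input width
    model.fc = nn.Linear(in_features, num_classes)
    return model


def freeze_all_but_head(model):
    """Optional: train only the new head, keep every other parameter fixed."""
    for name, param in model.named_parameters():
        param.requires_grad = name.startswith("fc.")
    return model


model = replace_head(TinyNet(), num_classes=2)
freeze_all_but_head(model)
out = model(torch.randn(4, 8))
trainable = [n for n, p in model.named_parameters() if p.requires_grad]
print(tuple(out.shape), trainable)
```

Both helpers only rely on the attribute name `fc`, so they should apply unchanged to a torchvision ResNet, whose head parameters are named `fc.weight` and `fc.bias`.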
# 2. Load part of the parameters of a pre-trained model to initialize a
#    self-defined model with a similar architecture.
# resnet50 is the pre-trained model;
# self_defined is the model you just defined (Net is your own nn.Module class).
resnet50 = models.resnet50(pretrained=True)
self_defined = Net(...)
pretrained_dict = resnet50.state_dict()
model_dict = self_defined.state_dict()
# keep only parameters whose name AND shape match the new model;
# a shape mismatch (e.g. a different fc layer) would make load_state_dict fail
pretrained_dict = {k: v for k, v in pretrained_dict.items()
                   if k in model_dict and v.shape == model_dict[k].shape}
# update & load
model_dict.update(pretrained_dict)
self_defined.load_state_dict(model_dict)
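The filtering step can be checked on two small hypothetical models, `Pretrained` and `SelfDefined`, that share a `backbone` layer but differ in their `fc` head and in an extra layer. In this sketch only entries with the same key and the same tensor shape are copied; everything else keeps the target model's own initialization.

```python
import torch
import torch.nn as nn


class Pretrained(nn.Module):
    """Hypothetical source model (plays the role of resnet50)."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 16)
        self.fc = nn.Linear(16, 10)


class SelfDefined(nn.Module):
    """Hypothetical target model: same backbone, different head, one extra layer."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 16)
        self.fc = nn.Linear(16, 2)    # shape differs from the source fc
        self.extra = nn.Linear(2, 2)  # no counterpart in the source


def load_matching(src, dst):
    """Copy every src parameter whose key and shape exist in dst; return copied keys."""
    src_sd = src.state_dict()
    dst_sd = dst.state_dict()
    matched = {k: v for k, v in src_sd.items()
               if k in dst_sd and v.shape == dst_sd[k].shape}
    dst_sd.update(matched)
    dst.load_state_dict(dst_sd)
    return sorted(matched)


torch.manual_seed(0)
src, dst = Pretrained(), SelfDefined()
copied = load_matching(src, dst)
print(copied)
```

`load_state_dict(..., strict=False)` tolerates missing or unexpected keys, but it still raises on shape mismatches, which is why the shape check in the filter is useful.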
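# 3. Save & load routines.
# routine 1 (recommended): save only the parameters
# torch.save(model.state_dict(), PATH)
# model = ModelClass(*args, **kwargs)
# model.load_state_dict(torch.load(PATH))
# model.eval()  # switch dropout/batch norm to inference mode before evaluating
# routine 2: pickle the whole model (its class must still be importable when loading)
# torch.save(model, PATH)
# model = torch.load(PATH, weights_only=False)  # PyTorch >= 2.6 defaults to weights_only=True, which rejects full models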
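Both routines can be exercised end to end with an in-memory `io.BytesIO` buffer standing in for `PATH` (`torch.save` and `torch.load` accept file-like objects). `make_model` is a hypothetical stand-in for `ModelClass(*args, **kwargs)`. The sketch assumes PyTorch 1.13 or newer, since it passes the `weights_only` argument.

```python
import io

import torch
import torch.nn as nn


def make_model():
    # Hypothetical architecture; routine 1 needs this code again to rebuild the model.
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


torch.manual_seed(0)
model = make_model()
x = torch.randn(3, 4)

# Routine 1: save only the state_dict, then rebuild the model and load it.
buf1 = io.BytesIO()  # in-memory stand-in for PATH
torch.save(model.state_dict(), buf1)
buf1.seek(0)
restored = make_model()
restored.load_state_dict(torch.load(buf1, map_location="cpu"))
restored.eval()  # inference mode (matters for dropout / batch norm)

# Routine 2: pickle the whole module; loading it needs weights_only=False
# on PyTorch >= 2.6, where the default became weights_only=True.
buf2 = io.BytesIO()
torch.save(model, buf2)
buf2.seek(0)
whole = torch.load(buf2, weights_only=False)

with torch.no_grad():
    same1 = torch.equal(model(x), restored(x))
    same2 = torch.equal(model(x), whole(x))
print(same1, same2)
```

Routine 1 is usually preferred because the saved file contains only tensors; routine 2 ties the file to the exact class and module layout used when saving.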