PyTorch Loading Pre-trained Models
# 1. Directly Load a Pre-trained Model
# https://github.com/pytorch/vision/tree/master/torchvision/models
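import torch.nn as nn
import torchvision.models as models

# Load ResNet-50 with ImageNet weights. On torchvision >= 0.13 the preferred
# spelling is models.resnet50(weights=models.ResNet50_Weights.DEFAULT).
resnet50 = models.resnet50(pretrained=True)
# or build the same architecture with random initialization
model = models.resnet50(pretrained=False)
# Maybe you want to modify the last fc layer? (2048 input features -> 2 classes)
resnet50.fc = nn.Linear(2048, 2)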
# 2. Load part of a pretrained model's parameters to initialize a self-defined model with a similar architecture.
# resnet50 is the pretrained model.
# self_defined is the model you defined yourself.
resnet50 = models.resnet50(pretrained=True)
self_defined = Net(...)
pretrained_dict = resnet50.state_dict()
model_dict = self_defined.state_dict()
# keep only entries whose name and shape both match the self-defined model
pretrained_dict = {k: v for k, v in pretrained_dict.items()
                   if k in model_dict and v.shape == model_dict[k].shape}
# update & load
model_dict.update(pretrained_dict)
self_defined.load_state_dict(model_dict)
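# A self-contained sketch of the partial-loading idea above, using two tiny
# stand-in modules instead of ResNet-50 so it runs without downloads. Source,
# Net, and load_matching are hypothetical names for illustration. The helper
# also reports which entries were copied and which were skipped, which helps
# catch silent name or shape mismatches.

```python
import torch
import torch.nn as nn


class Source(nn.Module):  # hypothetical stand-in for the pretrained network
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)
        self.fc = nn.Linear(8, 10)


class Net(nn.Module):  # hypothetical self-defined model
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)  # same name and shape: copied
        self.fc = nn.Linear(8, 2)       # same name, different shape: skipped
        self.extra = nn.Linear(2, 2)    # no counterpart: keeps its own init


def load_matching(target, source_state):
    """Copy entries of source_state whose name and shape match target."""
    target_state = target.state_dict()
    matched = {k: v for k, v in source_state.items()
               if k in target_state and v.shape == target_state[k].shape}
    target_state.update(matched)
    target.load_state_dict(target_state)
    skipped = sorted(set(source_state) - set(matched))
    return sorted(matched), skipped


source, net = Source(), Net()
loaded, skipped = load_matching(net, source.state_dict())
print("loaded:", loaded)
print("skipped:", skipped)
```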
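# 3. Save & load routines.
# routine 1 (recommended): save only the parameters (state_dict)
# torch.save(model.state_dict(), PATH)
# model = ModelClass(*args, **kwargs)
# model.load_state_dict(torch.load(PATH))
# model.eval()  # switch dropout / batch norm to inference mode before evaluating
# routine 2: pickle the whole model; loading needs the original class definition importable
# torch.save(model, PATH)
# model = torch.load(PATH)  # recent PyTorch releases may require weights_only=False here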
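# A sketch extending routine 1 to a resumable training checkpoint: the model
# and optimizer state_dicts are stored together with an epoch counter. The
# dictionary keys, the temporary path, and the tiny nn.Linear model are
# assumptions chosen for illustration, not a fixed convention.

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Save everything needed to resume training in one file.
path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")
torch.save(
    {
        "epoch": 3,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    },
    path,
)

# map_location="cpu" lets a checkpoint written on a GPU load on a CPU-only machine.
checkpoint = torch.load(path, map_location="cpu")

# Rebuild the objects first, then restore their state.
restored = nn.Linear(4, 2)
restored.load_state_dict(checkpoint["model"])
restored_opt = torch.optim.SGD(restored.parameters(), lr=0.1)
restored_opt.load_state_dict(checkpoint["optimizer"])
restored.eval()  # call restored.train() instead when resuming training
```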