Skip to content

Instantly share code, notes, and snippets.

@ai2ys
Created April 15, 2023 15:29
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ai2ys/fc12aeab655de7ca4d3e1d1993ef7fd1 to your computer and use it in GitHub Desktop.
Save ai2ys/fc12aeab655de7ca4d3e1d1993ef7fd1 to your computer and use it in GitHub Desktop.
PyTorch - Partially initialize a model with pretrained weights and partially freeze it
import logging
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.models import VGG16_Weights
class Vgg16PartiallyFrozenFeatures(nn.Module):
    """VGG16 whose leading convolutional blocks use frozen pretrained weights.

    A torchvision VGG16 is built without weights. The layers of
    ``features`` that belong to the first ``num_frozen_blocks`` blocks (a
    block ends at an ``nn.MaxPool2d``) are loaded from the
    ``VGG16_Weights.IMAGENET1K_V1`` checkpoint. Those layers are then frozen:
    ``requires_grad`` is set to ``False`` and they are kept in eval mode.
    All other layers keep their fresh initialization and remain trainable.

    Args:
        num_frozen_blocks: Number of leading blocks of ``features`` to
            initialize from the pretrained model and freeze.
    """

    def __init__(self, num_frozen_blocks=3):
        super().__init__()
        self.model = models.vgg16(weights=None)
        vgg16_pretrained = models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1)

        # Collect the leading children of ``features``. ``block_count`` is
        # bumped *at* each max-pool, so the pool closing the last frozen block
        # is not included. Once the limit is reached no later child can
        # qualify, so the loop stops there.
        self.frozen_childs = nn.Sequential()
        block_count = 0
        for child in self.model.features.children():
            if isinstance(child, nn.MaxPool2d):
                block_count += 1
            if block_count >= num_frozen_blocks:
                break
            self.frozen_childs.append(child)
            logging.debug(f"freezing child: {child._get_name()}")

        # ``frozen_childs`` holds the *same* module objects as the leading
        # part of ``self.model.features``. Because it is a contiguous prefix,
        # ``append`` gives it the same indices ("0", "2", ...). The pretrained
        # ``features`` state dict therefore maps onto it key-for-key, and the
        # weights land directly in ``self.model``. Entries for the remaining
        # layers are reported as unexpected and ignored via ``strict=False``.
        # NOTE: the frozen layers are registered under both ``model`` and
        # ``frozen_childs``, so ``self.state_dict()`` lists them twice.
        vgg16_features_state_dict = vgg16_pretrained.features.state_dict()
        incompatible_keys = self.frozen_childs.load_state_dict(
            vgg16_features_state_dict, strict=False
        )
        self.frozen_childs.requires_grad_(False)
        # Keep the frozen layers in eval mode from the start, matching train().
        self.frozen_childs.eval()
        logging.debug(f"Missing keys: {incompatible_keys.missing_keys}")
        logging.debug(f"Unexpected keys: {incompatible_keys.unexpected_keys}")

        if logging.getLogger().isEnabledFor(logging.DEBUG):
            logging.debug("Comparing pretrained to custom weights - ")
            # Both models are VGG16, so their state-dict entries line up in order.
            for (key, own), (_, pretrained) in zip(
                self.model.state_dict().items(),
                vgg16_pretrained.state_dict().items(),
            ):
                logging.debug(f"equal: {key}, {torch.allclose(own, pretrained)}")

    def train(self, mode: bool = True):
        """Set the training mode while keeping the frozen layers in eval mode.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            ``self``, matching ``torch.nn.Module.train``, so chained calls
            such as ``model.train().to(device)`` keep working.
        """
        # super().train() updates ``self.training`` and recurses into all
        # submodules (``self.model`` and ``self.frozen_childs``).
        super().train(mode)
        # Frozen layers stay in inference behaviour regardless of ``mode``.
        self.frozen_childs.eval()
        if logging.getLogger().isEnabledFor(logging.DEBUG):
            logging.debug("requires grad...")
            for i, (name, param) in enumerate(self.model.named_parameters()):
                logging.debug(f"{i}, {name}, {param.requires_grad}")
            logging.debug("training...")
            for i, (name, module) in enumerate(self.model.named_modules()):
                logging.debug(f"{i}, {name}, {module.training}")
        return self

    def forward(self, x):
        """Run ``x`` through the wrapped VGG16 and return its output."""
        # Call the module rather than ``.forward`` so registered hooks run.
        return self.model(x)
# %%
# Show DEBUG messages, including the freezing and weight-loading diagnostics
# that the model logs.
logging.basicConfig(format='%(asctime)s - %(message)s', level=logging.DEBUG)
test = Vgg16PartiallyFrozenFeatures()
# %%
# Put the model in training mode (explicit form of the default argument).
test.train(mode=True)
# %%
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment