Skip to content

Instantly share code, notes, and snippets.

@InnovArul
Created November 18, 2021 20:03
Show Gist options
  • Save InnovArul/871cf866b737b306eb073f3c37a6a8fd to your computer and use it in GitHub Desktop.
deepcopy and reset params
import torchvision, copy
import torch, torch.nn as nn
def reset_all_weights(model: nn.Module) -> None:
    """Re-initialise, in place, every module under ``model`` that can reset itself.

    ``model.apply`` is used to visit each submodule as well as ``model``
    itself; any visited module exposing a callable ``reset_parameters``
    attribute has it invoked. Modules without such an attribute are skipped.

    Args:
        model: Root module whose (sub)modules should be re-initialised.

    Returns:
        None. ``model`` is modified in place.

    refs:
        - https://discuss.pytorch.org/t/how-to-re-set-alll-parameters-in-a-network/20819/6
        - https://stackoverflow.com/questions/63627997/reset-parameters-of-a-neural-network-in-pytorch
        - https://pytorch.org/docs/stable/generated/torch.nn.Module.html
    """

    @torch.no_grad()
    def weight_reset(m: nn.Module) -> None:
        # Only call reset_parameters when the module actually provides a
        # callable of that name; everything else is left untouched.
        reset_parameters = getattr(m, "reset_parameters", None)
        if callable(reset_parameters):
            reset_parameters()

    # apply() runs the callback recursively on every submodule, see:
    # https://pytorch.org/docs/stable/generated/torch.nn.Module.html
    model.apply(fn=weight_reset)
class Densenet121_conv(nn.Module):
    """DenseNet-121 feature trunk plus five re-initialised branches and heads.

    The constructor loads ``torchvision.models.densenet121(pretrained=True)``
    and splits its ``features`` container:

    * ``trunk``: every entry of ``features`` except the last two (these keep
      the loaded weights).
    * ``branch1`` .. ``branch5``: independent deep copies of the last two
      entries of ``features``, each passed through ``reset_all_weights`` so it
      does not start from the pretrained values.
    * ``classifier1`` .. ``classifier5``: ``nn.Linear(1024, num_classes[i])``,
      one per entry of ``num_classes``.

    NOTE: this class defines no ``forward``; it only builds the submodules.

    Args:
        num_classes: Output sizes for the five classifiers, indexed 0..4.
            Must contain at least five entries (fewer raises ``IndexError``).
    """

    def __init__(self, num_classes=[1, 4, 2, 8, 3]):
        super(Densenet121_conv, self).__init__()
        original_model = torchvision.models.densenet121(pretrained=True)

        # Store a copy: the default list object is shared between calls, so
        # keeping a reference to it would let one instance's edits leak into
        # every later default-constructed instance.
        self.num_classes = list(num_classes)

        self.trunk = original_model.features[:-2]
        common_modules = original_model.features[-2:]

        branch_count = 5

        # Deep-copy the shared tail once per branch, then reset the copy's
        # parameters. Building and resetting the same local object in a loop
        # guarantees every branch (including branch5) really gets reset.
        for idx in range(1, branch_count + 1):
            branch = copy.deepcopy(common_modules)
            reset_all_weights(branch)
            setattr(self, f"branch{idx}", branch)

        # One linear head per branch; registered after all branches so the
        # submodule order stays trunk, branch1..5, classifier1..5.
        for idx in range(1, branch_count + 1):
            head = nn.Linear(1024, self.num_classes[idx - 1])
            setattr(self, f"classifier{idx}", head)
if __name__ == '__main__':
    # Script entry point: construct the model once with its default arguments.
    m = Densenet121_conv()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment