@FlorianMuellerklein
Forked from chauhan-utk/file.md
Created August 3, 2018 18:03
PyTorch replace pretrained model layers

This code snippet shows how to replace a layer in a pretrained model. In the following code, we replace every ReLU activation with SELU in a ResNet-18 model. One caveat: rebuilding the model with nn.Sequential discards any logic defined in the original forward method (for ResNet, the flatten between the average pool and the final linear layer), so the rebuilt model may need a small adjustment before it can run a forward pass.

import torch
import torch.nn as nn
from torchvision import models

resnet18 = models.resnet18(pretrained=True)

def funct(list_mods):
    # Walk the list of child modules, swapping activations in place.
    for i in range(len(list_mods)):
        if list_mods[i].__class__.__name__ == "ReLU":
            list_mods[i] = nn.SELU(inplace=True)
        elif list_mods[i].__class__.__name__ in ("Sequential", "BasicBlock"):
            # Recurse into container modules so nested ReLUs are replaced too.
            list_mods[i] = nn.Sequential(*funct(list(list_mods[i].children())))
    return list_mods

resnet18_selu = nn.Sequential(*funct(list(resnet18.children())))
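Since the replacement logic is just a recursive walk over a module tree, it can be illustrated without PyTorch installed. Below is a minimal stdlib-only sketch using plain classes as stand-ins for the nn modules; the class names mirror the PyTorch ones, but everything here is illustrative, not the torch API:

```python
# Stand-ins for nn.ReLU / nn.SELU / nn.Sequential, for illustration only.
class ReLU:
    pass

class SELU:
    pass

class Sequential:
    def __init__(self, *mods):
        self.mods = list(mods)

    def children(self):
        return self.mods

def swap(mods):
    # Same pattern as funct() above: replace leaves, recurse into containers.
    for i, m in enumerate(mods):
        if m.__class__.__name__ == "ReLU":
            mods[i] = SELU()
        elif m.__class__.__name__ == "Sequential":
            mods[i] = Sequential(*swap(list(m.children())))
    return mods

tree = Sequential(ReLU(), Sequential(ReLU(), SELU()))
new = Sequential(*swap(list(tree.children())))
# Every ReLU, at any depth, has now been replaced by a SELU.
```

The key design point is that matching on the class name lets the same branch handle both leaf modules to replace and containers to recurse into.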