@czotti
Created February 21, 2020 15:59
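import torch
import torchvision

# torchvision >= 0.13 deprecates pretrained=True in favour of
# weights=torchvision.models.ResNet18_Weights.DEFAULT.
model = torchvision.models.resnet18(pretrained=True)


def bn_to_in_inplace(model):
    """Recursively replace every BatchNorm2d in `model` with InstanceNorm2d, in place."""
    for name, module in model.named_children():
        if isinstance(module, torch.nn.BatchNorm2d):
            # InstanceNorm2d defaults to affine=False, so the pretrained BN scale/shift are dropped.
            setattr(model, name, torch.nn.InstanceNorm2d(module.num_features))
        else:
            # Recurse into containers (Sequential, BasicBlock, ...).
            bn_to_in_inplace(module)


bn_to_in_inplace(model)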