
@phillies
Created September 19, 2019 05:27
Using a pretrained Unet with any number of input channels
import torch
import segmentation_models_pytorch as smp
in_channels = 8
# load a pretrained Unet with a resnet34 encoder and ImageNet weights (these are also the defaults of smp.Unet())
net = smp.Unet(encoder_name="resnet34", encoder_weights="imagenet")
# Checking the parameters of the first layer
print(net.encoder.conv1)
# output: Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
# This replaces the first layer with a freshly initialized one, so you lose the pretrained weights for that layer
net.encoder.conv1 = torch.nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
# test
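# test with a random image; height and width should be divisible by 32 for a 5-stage encoder
x = torch.randn(1, in_channels, 128, 128)
with torch.no_grad():
    y = net(x)
print(y.shape)  # torch.Size([1, 1, 128, 128]) with the default classes=1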
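Replacing conv1 as above throws away the ImageNet kernels of the encoder stem. One common heuristic for keeping them, sketched here under assumptions (it is not what the gist itself does, and `expand_first_conv` is a hypothetical helper name), is to build the wider convolution, copy the pretrained kernels into the first three input channels, and initialize every extra channel with the mean of the RGB kernels. The demo uses a randomly initialized `nn.Conv2d` as a stand-in for the pretrained stem, so it runs with PyTorch alone.

```python
import torch
import torch.nn as nn


def expand_first_conv(old_conv: nn.Conv2d, in_channels: int) -> nn.Conv2d:
    """Hypothetical helper: return a copy of old_conv that accepts in_channels inputs.

    Assumed heuristic: keep the pretrained kernels for the first
    old_conv.in_channels inputs and give every extra input channel the mean
    of the pretrained kernels taken across input channels.
    """
    assert old_conv.groups == 1, "sketch only handles ungrouped convolutions"
    new_conv = nn.Conv2d(
        in_channels=in_channels,
        out_channels=old_conv.out_channels,
        kernel_size=old_conv.kernel_size,
        stride=old_conv.stride,
        padding=old_conv.padding,
        dilation=old_conv.dilation,
        bias=old_conv.bias is not None,
    )
    with torch.no_grad():
        old_w = old_conv.weight  # shape: (out_channels, old_in_channels, kH, kW)
        n_keep = min(in_channels, old_w.shape[1])
        # copy the pretrained kernels for the original (e.g. RGB) channels
        new_conv.weight[:, :n_keep] = old_w[:, :n_keep]
        if in_channels > n_keep:
            # fill each extra channel with the mean kernel over the original channels
            mean_w = old_w.mean(dim=1, keepdim=True)  # (out_channels, 1, kH, kW)
            new_conv.weight[:, n_keep:] = mean_w.expand(-1, in_channels - n_keep, -1, -1)
        if old_conv.bias is not None:
            new_conv.bias.copy_(old_conv.bias)
    return new_conv


# Stand-in for the pretrained ResNet stem (random weights, for illustration only)
old_conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
new_conv = expand_first_conv(old_conv, in_channels=8)
x = torch.randn(1, 8, 128, 128)
print(new_conv(x).shape)  # torch.Size([1, 64, 64, 64])
```

With the gist's model, the call would be `net.encoder.conv1 = expand_first_conv(net.encoder.conv1, in_channels)`. Note that the extra channels add to the pre-activation, so its magnitude grows with the number of channels; some implementations instead repeat the RGB kernels cyclically and rescale them by 3 / in_channels to keep the output on a comparable scale. Which choice works better likely depends on how the extra channels relate to RGB.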