Created September 19, 2019 05:27
Using a pretrained Unet with any number of input channels
import torch
import segmentation_models_pytorch as smp

in_channels = 8

# Load a pretrained Unet with a resnet34 encoder and imagenet weights
# (these are also the library defaults, spelled out here for clarity)
net = smp.Unet(encoder_name="resnet34", encoder_weights="imagenet")

# Check the parameters of the first layer
print(net.encoder.conv1)
# output: Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

# Replace the first layer so it accepts `in_channels` inputs.
# The new layer is randomly initialized, so the pretrained weights for that layer are lost.
net.encoder.conv1 = torch.nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

# Test: run a random batch with `in_channels` channels through the network
x = torch.randn(1, in_channels, 128, 128)  # test image
y = net(x)
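Replacing the first layer as above discards its pretrained weights. One possible way to keep some of that information is to seed the new layer from the old one: the sketch below repeats the pretrained input-channel filters cyclically across the new channels and rescales them so the overall activation magnitude stays roughly comparable. This is not part of the original gist; `inflate_first_conv` is a hypothetical helper name, the cyclic-copy-and-rescale scheme is only one of several reasonable choices, and the code assumes an ungrouped convolution (`groups == 1`). A plain `Conv2d` stands in for `net.encoder.conv1` so the example runs without downloading weights.

```python
import torch
from torch import nn


def inflate_first_conv(conv: nn.Conv2d, in_channels: int) -> nn.Conv2d:
    """Hypothetical helper: return a conv with `in_channels` inputs, seeded from `conv`.

    Assumes `conv.groups == 1`. The old input-channel filters are repeated
    cyclically and scaled by old_in / in_channels.
    """
    if conv.groups != 1:
        raise ValueError("this sketch only handles groups == 1")

    new_conv = nn.Conv2d(
        in_channels=in_channels,
        out_channels=conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        bias=conv.bias is not None,
    )

    with torch.no_grad():
        old_in = conv.weight.shape[1]
        # Map each new input channel to an old one: 0, 1, 2, 0, 1, 2, ...
        idx = torch.arange(in_channels) % old_in
        new_conv.weight.copy_(conv.weight[:, idx] * (old_in / in_channels))
        if conv.bias is not None:
            new_conv.bias.copy_(conv.bias)
    return new_conv


# Stand-in for the pretrained first layer (same shape as net.encoder.conv1)
old_conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
new_conv = inflate_first_conv(old_conv, in_channels=8)

x = torch.randn(1, 8, 128, 128)  # test image with 8 channels
y = new_conv(x)
print(new_conv)
print(y.shape)
```

With the model from the snippet above, the equivalent call would be `net.encoder.conv1 = inflate_first_conv(net.encoder.conv1, in_channels)`. Recent versions of segmentation_models_pytorch also accept an `in_channels` argument on `smp.Unet`, which may make the manual replacement unnecessary; check the documentation of the installed version for how it initializes the extra channels.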