Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
This gist shows how to define a dual-head model that predicts both a segmentation mask and a global (image-level) class.
from pytorch_toolbelt.modules import ABN, GlobalAvgPool2d
from pytorch_toolbelt.modules import decoders as D
from pytorch_toolbelt.modules import encoders as E
from torch import nn
from torch.nn import functional as F
class FPNCatSegmentationModel(nn.Module):
    """Dual-head network: a per-pixel mask head and an image-level classifier.

    Both heads share one encoder.

    * Mask head: an ``FPNCatDecoder`` consumes the full list of encoder
      feature maps and produces a mask with ``num_mask_classes`` channels.
    * Classifier head: global-average-pools the last (deepest) encoder feature
      map, applies dropout, and maps it to ``num_classifer_classes`` outputs
      with a single linear layer.

    No activation (sigmoid/softmax) is applied to either output in this
    module, so callers must apply one themselves if needed.

    Args:
        num_mask_classes: Number of output channels of the mask head.
        num_classifer_classes: Number of outputs of the classification head.
            (The misspelling is kept so existing keyword callers keep working.)
        dropout: Dropout rate, forwarded to the decoder and also used right
            before the classifier's linear layer.
        abn_block: Block class forwarded to the decoder (``ABN`` by default).
        fpn_channels: FPN channel width forwarded to the decoder.
        full_size_mask: If True, the mask is bilinearly resized to the input's
            spatial size in ``forward``. Otherwise it is returned at the
            decoder's output resolution.
        encoder: Optional backbone replacing the default
            ``E.Resnet50Encoder()``. It must be callable on the input and
            return an indexable sequence of feature maps. It must also expose
            ``output_filters``, the per-map channel counts; the last entry
            sizes the classifier's linear layer, so it must match the channel
            count of the last feature map.
    """

    def __init__(
        self,
        num_mask_classes: int,
        num_classifer_classes: int,
        dropout=0.25,
        abn_block=ABN,
        fpn_channels=256,
        full_size_mask=True,
        encoder=None,
    ):
        super().__init__()
        # Fall back to the original ResNet-50 backbone so callers that do not
        # pass `encoder` get exactly the previous architecture. Submodules are
        # registered in the same order as before (encoder, decoder,
        # classifier), keeping state_dict keys unchanged.
        self.encoder = E.Resnet50Encoder() if encoder is None else encoder
        filters = self.encoder.output_filters

        self.decoder = D.FPNCatDecoder(
            feature_maps=filters,
            output_channels=num_mask_classes,
            # NOTE(review): presumably disables deep-supervision side outputs.
            # forward() relies on the decoder returning a single mask tensor.
            dsv_channels=None,
            fpn_channels=fpn_channels,
            abn_block=abn_block,
            dropout=dropout,
        )

        self.classifier = nn.Sequential(
            GlobalAvgPool2d(flatten=True),
            nn.Dropout(dropout),
            # Input width is the channel count of the last feature map, which
            # is what the pooling layer receives in forward().
            nn.Linear(filters[-1], num_classifer_classes),
        )

        self.full_size_mask = full_size_mask

    def forward(self, x):
        """Run the shared encoder and both heads on a batch.

        Args:
            x: Input batch. It must be 4-D ``(B, C, H, W)`` when
                ``full_size_mask`` is True, because the mask is bilinearly
                resized to ``x.size()[2:]``. The default ResNet-50 encoder
                presumably expects 3-channel images (confirm for your data).

        Returns:
            dict with two entries:

            * ``"mask"``: the decoder output. It is resized to ``(H, W)`` when
              ``full_size_mask`` is True.
            * ``"classifier"``: the classifier head's linear-layer output,
              whose last dimension is ``num_classifer_classes``.
        """
        features = self.encoder(x)

        # Mask head consumes the whole list of encoder feature maps.
        mask = self.decoder(features)

        # Classifier head only looks at the last (deepest) feature map.
        classifier = self.classifier(features[-1])

        if self.full_size_mask:
            # Upsample the decoder output back to the input resolution.
            mask = F.interpolate(mask, size=x.size()[2:], mode="bilinear", align_corners=False)

        return {
            "mask": mask,
            "classifier": classifier,
        }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment