Skip to content

Instantly share code, notes, and snippets.

@phizaz
Last active January 25, 2021 14:15
Show Gist options
  • Save phizaz/9eb7f39f3ef2fd1b9270e5d7d0e66037 to your computer and use it in GitHub Desktop.
Save phizaz/9eb7f39f3ef2fd1b9270e5d7d0e66037 to your computer and use it in GitHub Desktop.
li2018
from segmentation_models_pytorch.encoders import get_encoder
from .mil import MILPool
def make_net_li2018(
    backbone,
    n_out,
    n_in=1,
    n_dec_ch=512,
    out_size=20,
    pooling='milpool',
    min_val=0.98,
    pretrain='imagenet',
    **kwargs,
):
    """Build a Li et al. (2018)-style patch-classification network class.

    The returned ``nn.Module`` subclass runs a ``get_encoder`` backbone,
    resizes its last feature map to an ``out_size`` x ``out_size`` grid,
    predicts per-patch scores with a small conv head, and pools those scores
    into one score per output channel.

    Args:
        backbone: encoder name forwarded to ``get_encoder``.
        n_out: number of output channels (one score map per output).
        n_in: number of input image channels.
        n_dec_ch: width of the 3x3 conv in the prediction head.
        out_size: side length of the grid the features are resized to.
        pooling: one of ``'maxpool'``, ``'avgpool'`` or ``'milpool'``.
        min_val: forwarded to ``MILPool``; only used (and only added to the
            model name) when ``pooling == 'milpool'``.
        pretrain: ``weights`` argument forwarded to ``get_encoder``; a falsy
            value is left out of the model name.
        **kwargs: accepted and ignored.

    Returns:
        A zero-argument ``nn.Module`` subclass decorated with ``rename(name)``
        (presumably this sets its name to the descriptive string built below).
        Its ``forward`` returns ``{'pred': (N, n_out) tensor,
        'seg': (N, n_out, out_size, out_size) float32 tensor}``.

    Raises:
        KeyError: when the returned class is instantiated with an unknown
            ``pooling`` value.
    """
    name = f'li2018,out{out_size}-{backbone}-{pooling}-out{n_out}'
    if n_in != 1:
        name += f'in{n_in}'
    if pretrain:
        name += f'-pretrain{pretrain}'
    if pooling == 'milpool' and min_val is not None:
        name += f',min{min_val}'

    # Factories instead of instances: only the selected pooling layer is ever
    # constructed. An unknown ``pooling`` still raises KeyError when Net() is
    # instantiated.
    pooling_factories = {
        'maxpool': lambda: nn.AdaptiveMaxPool2d(1),
        'avgpool': lambda: nn.AdaptiveAvgPool2d(1),
        'milpool': lambda: MILPool(min_val=min_val, apply_sigmoid=True, ret_logit=True),
    }

    @rename(name)
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.net = get_encoder(
                name=backbone,
                in_channels=n_in,
                weights=pretrain,
            )
            self.out = nn.Sequential(
                # Same op as the deprecated nn.UpsamplingBilinear2d
                # (bilinear, align_corners=True); it has no parameters, so
                # state_dict keys are unchanged.
                nn.Upsample(size=(out_size, out_size), mode='bilinear', align_corners=True),
                nn.Conv2d(self.net.out_channels[-1], n_dec_ch, 3, padding=1),
                nn.BatchNorm2d(n_dec_ch),
                nn.ReLU(),
                nn.Conv2d(n_dec_ch, n_out, 1, bias=True),
            )
            self.pool = pooling_factories[pooling]()

        def forward(self, x):
            # select the last layer of the encoder's outputs
            x = self.net(x)[-1]
            # cast to fp32; presumably so pooling is numerically safe under
            # mixed precision
            seg = self.out(x).float()
            pred = torch.flatten(self.pool(seg), 1)
            return {
                'pred': pred,
                'seg': seg,
            }

    return Net
def mil_output(p, min_val):
    """Combine per-patch probabilities into one probability per (n, c).

    Computes ``1 - prod(not_p)`` over all spatial positions (a noisy-OR),
    where each factor ``not_p = 1 - p`` is first mapped affinely from
    [0, 1] onto [min_val, 1].

    Args:
        p: 4-D tensor ``(n, c, h, w)``; presumably probabilities in [0, 1]
            (``MILPool`` passes sigmoid outputs when ``apply_sigmoid`` is set).
            Any floating dtype and any memory layout is accepted.
        min_val: cap the min value of 1-p to prevent underflow; every factor
            of the product is at least ``min_val``. Must be a number.

    Returns:
        float32 tensor of shape ``(n, c, 1, 1)``.
    """
    n, c, _, _ = p.shape
    # Cast before the arithmetic: in fp16, 1 - p rounds small probabilities
    # (below ~2.4e-4) away completely.
    not_p = 1 - p.float()
    not_p = (1 - min_val) * not_p + min_val
    # reshape rather than view: view raises on non-contiguous inputs
    # (e.g. a transposed/permuted p).
    not_p = not_p.reshape(n, c, -1)
    pred = 1 - torch.prod(not_p, dim=-1, keepdim=True)
    return pred.view(n, c, 1, 1)
class MILPool(nn.Module):
    """
    Multi-instance pooling:
    The output is positive when there is at least one positive patch
    Found in:
    Li, Zhe, Chong Wang, Mei Han, Yuan Xue, Wei Wei, Li-Jia Li, and Li Fei-Fei. 2018.
    “Thoracic Disease Identification and Localization with Limited Supervision.”
    In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, 8290–99.

    Input is a 4-D ``(n, c, h, w)`` tensor of per-patch scores; output has
    shape ``(n, c, 1, 1)``.

    Args:
        min_val: floor for each per-patch ``1 - p`` factor, forwarded to
            ``mil_output``. NOTE(review): intended to allow ``None`` = auto,
            but no auto mode is implemented; ``None`` fails inside
            ``mil_output`` (``1 - min_val``).
        apply_sigmoid: apply ``torch.sigmoid`` to the input first (i.e. the
            input is treated as logits).
        ret_logit: returns as logit (not prob), to keep the interface
            invariance with logit-producing pooling layers.
    """
    def __init__(self, min_val=0.98, apply_sigmoid=True, ret_logit=False):
        super().__init__()
        self.min_val = min_val
        self.apply_sigmoid = apply_sigmoid
        self.ret_logit = ret_logit

    def forward(self, x):
        if self.apply_sigmoid:
            x = torch.sigmoid(x)
        pred = mil_output(x, min_val=self.min_val)
        if self.ret_logit:
            # pred can be exactly 0 (every rescaled factor rounds to 1.0) or
            # exactly 1 (the product underflows). Then log(pred / (1 - pred))
            # is -inf/+inf, which downstream losses may turn into NaN.
            # Clamp into the open interval first.
            eps = torch.finfo(pred.dtype).eps
            pred = pred.clamp(min=eps, max=1 - eps)
            # the logit function inverts the sigmoid
            pred = torch.log(pred / (1 - pred))
        return pred
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment