@cenkbircanoglu
Created January 27, 2023 21:55
U-Net model to use in isim

import torch
import torch.nn.functional as F

from models.unet.net import Net


class CAM(Net):
    def __init__(self, *args, **kwargs):
        super(CAM, self).__init__(*args, **kwargs)

    def forward(self, x):
        return self.forward_cam(x)

    def forward_cam(self, x):
        # Run only the U-Net encoder, down to the bottleneck.
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))
        bottleneck = self.bottleneck(self.pool4(enc4))

        # Apply a detached copy of the 1x1 classifier weights as a convolution
        # over the bottleneck features to obtain class activation maps.
        weights = torch.zeros_like(self.classifier.weight)
        with torch.no_grad():
            weights.set_(self.classifier.weight.detach())
        x = F.relu(F.conv2d(bottleneck, weight=weights))

        # Merge the maps of the two batch items, flipping the second back
        # horizontally (the batch is expected to hold an image and its flip).
        x = x[0] + x[1].flip(-1)
        return x


if __name__ == '__main__':
    from torchsummary import summary

    model = CAM()
    summary(model, input_size=(3, 320, 320))
    x = torch.rand([2, 3, 320, 320])
    y = model(x)
    print(y.shape)
    assert y.shape == (20, 20, 20)

    model = CAM(init_features=32)
    summary(model, input_size=(3, 320, 320))
    x = torch.rand([2, 3, 320, 320])
    y = model(x)
    assert y.shape == (20, 20, 20)

    model = CAM(mid_ch=64)
    summary(model, input_size=(3, 320, 320))
    x = torch.rand([2, 3, 320, 320])
    y = model(x)
    assert y.shape == (20, 20, 20)
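forward_cam adds the activation map of the first batch item to the horizontally re-flipped map of the second, so it expects a two-image batch holding an image and its horizontal flip. A minimal usage sketch under that assumption (the import path follows the one used in the segmentation file below):

from models.unet.cam import CAM
import torch

img = torch.rand([3, 320, 320])           # stands in for a preprocessed RGB image
batch = torch.stack([img, img.flip(-1)])  # (2, 3, 320, 320): image + horizontal flip

model = CAM()
model.eval()
with torch.no_grad():
    cam = model(batch)  # merged class activation maps, shape (num_classes, 20, 20)
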
from collections import OrderedDict

import torch
import torch.nn as nn

from models.layers import gap2d


def _block(in_channels, features, name):
    # Two (Conv3x3 -> BatchNorm -> ReLU) layers: the basic U-Net building block.
    return nn.Sequential(
        OrderedDict(
            [
                (
                    name + "conv1",
                    nn.Conv2d(
                        in_channels=in_channels,
                        out_channels=features,
                        kernel_size=3,
                        padding=1,
                        bias=False,
                    ),
                ),
                (name + "norm1", nn.BatchNorm2d(num_features=features)),
                (name + "relu1", nn.ReLU(inplace=True)),
                (
                    name + "conv2",
                    nn.Conv2d(
                        in_channels=features,
                        out_channels=features,
                        kernel_size=3,
                        padding=1,
                        bias=False,
                    ),
                ),
                (name + "norm2", nn.BatchNorm2d(num_features=features)),
                (name + "relu2", nn.ReLU(inplace=True)),
            ]
        )
    )
class Net(nn.Module):
    def __init__(self, in_ch=3, mid_ch=32, out_ch=1, num_classes=20, *args, **kwargs):
        super(Net, self).__init__()
        self.out_channels = out_ch
        self.num_classes = num_classes

        # Encoder: four down-sampling stages followed by a bottleneck.
        self.encoder1 = _block(in_ch, mid_ch, name="enc1")
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = _block(mid_ch, mid_ch * 2, name="enc2")
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = _block(mid_ch * 2, mid_ch * 4, name="enc3")
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = _block(mid_ch * 4, mid_ch * 8, name="enc4")
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.bottleneck = _block(mid_ch * 8, mid_ch * 16, name="bottleneck")

        # Decoder: four up-sampling stages with skip connections and a 1x1 output conv.
        self.upconv4 = nn.ConvTranspose2d(
            mid_ch * 16, mid_ch * 8, kernel_size=2, stride=2
        )
        self.decoder4 = _block((mid_ch * 8) * 2, mid_ch * 8, name="dec4")
        self.upconv3 = nn.ConvTranspose2d(
            mid_ch * 8, mid_ch * 4, kernel_size=2, stride=2
        )
        self.decoder3 = _block((mid_ch * 4) * 2, mid_ch * 4, name="dec3")
        self.upconv2 = nn.ConvTranspose2d(
            mid_ch * 4, mid_ch * 2, kernel_size=2, stride=2
        )
        self.decoder2 = _block((mid_ch * 2) * 2, mid_ch * 2, name="dec2")
        self.upconv1 = nn.ConvTranspose2d(
            mid_ch * 2, mid_ch, kernel_size=2, stride=2
        )
        self.decoder1 = _block(mid_ch * 2, mid_ch, name="dec1")
        self.conv = nn.Conv2d(
            in_channels=mid_ch, out_channels=out_ch, kernel_size=1
        )

        # 1x1 classifier applied to the globally pooled bottleneck features.
        self.classifier = nn.Conv2d(mid_ch * 16, num_classes, 1, bias=False)

        # Parameter groups used by trainable_parameters().
        self.encoder_modules = nn.ModuleList(
            [self.encoder1, self.pool1, self.encoder2, self.pool2, self.encoder3, self.pool3, self.encoder4,
             self.pool4, self.bottleneck])
        self.decoder_modules = nn.ModuleList(
            [self.upconv4, self.decoder4, self.upconv3, self.decoder3, self.upconv2, self.decoder2, self.upconv1,
             self.decoder1, self.conv])
        self.classifier_modules = nn.ModuleList([self.classifier])
    def forward(self, x):
        # Classification-only forward pass: encode, global-average-pool, classify.
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))
        bottleneck = self.bottleneck(self.pool4(enc4))
        cls_label_pred = gap2d(bottleneck, keepdims=True)
        cls_label_pred = self.classifier(cls_label_pred)
        cls_label_pred = cls_label_pred.view(-1, self.num_classes)
        return cls_label_pred

    def trainable_parameters(self):
        return (list(self.encoder_modules.parameters()), list(self.classifier_modules.parameters()))
if __name__ == '__main__':
    from torchsummary import summary

    model = Net()
    summary(model, input_size=(3, 320, 320))
    x = torch.rand([2, 3, 320, 320])
    y = model(x)
    print(y.shape)
    assert y.shape == (2, 20)

    model = Net(mid_ch=32)
    summary(model, input_size=(3, 320, 320))
    x = torch.rand([2, 3, 320, 320])
    y = model(x)
    assert y.shape == (2, 20)

    model = Net(mid_ch=64)
    summary(model, input_size=(3, 320, 320))
    x = torch.rand([2, 3, 320, 320])
    y = model(x)
    assert y.shape == (2, 20)
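gap2d comes from models.layers, which is not included in this gist. Given how it is used above (global average pooling of the bottleneck before the 1x1 classifier, with keepdims=True producing an (N, C, 1, 1) tensor), a plausible stand-in is sketched below; the real helper may differ.

import torch


def gap2d(x, keepdims=False):
    # Hypothetical stand-in for models.layers.gap2d: global average pooling
    # over the spatial dimensions of an NCHW feature map.
    return x.mean(dim=(2, 3), keepdim=keepdims)


# e.g. gap2d(torch.rand(2, 512, 20, 20), keepdims=True).shape == (2, 512, 1, 1)
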
import torch

from models.generate_label import generate_pseudo_label
from models.layers import gap2d
from models.unet.cam import CAM
from utils import count_parameters


class Segmentation(CAM):
    def __init__(self, *args, **kwargs):
        super(Segmentation, self).__init__(*args, **kwargs)

    def forward(self, x):
        # Encoder path shared with the CAM/classification branch.
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))
        bottleneck = self.bottleneck(self.pool4(enc4))

        # Classification head: global average pooling + 1x1 classifier.
        cls_label_pred = gap2d(bottleneck, keepdims=True)
        cls_label_pred = self.classifier(cls_label_pred)
        cls_label_pred = cls_label_pred.view(-1, self.num_classes)

        # Decoder path with skip connections for the segmentation head.
        dec4 = self.upconv4(bottleneck)
        dec4 = torch.cat((dec4, enc4), dim=1)
        dec4 = self.decoder4(dec4)
        dec3 = self.upconv3(dec4)
        dec3 = torch.cat((dec3, enc3), dim=1)
        dec3 = self.decoder3(dec3)
        dec2 = self.upconv2(dec3)
        dec2 = torch.cat((dec2, enc2), dim=1)
        dec2 = self.decoder2(dec2)
        dec1 = self.upconv1(dec2)
        dec1 = torch.cat((dec1, enc1), dim=1)
        dec1 = self.decoder1(dec1)
        seg_label_pred = self.conv(dec1)

        return cls_label_pred, seg_label_pred

    def trainable_parameters(self):
        return (list(self.encoder_modules.parameters()), list(self.classifier_modules.parameters()),
                list(self.decoder_modules.parameters()))


if __name__ == '__main__':
    from torchsummary import summary

    model = Segmentation(in_ch=3, out_ch=21, mid_ch=32, num_classes=20)
    summary(model, input_size=(3, 320, 320))
    x = torch.rand([2, 3, 320, 320])
    cls_pred, seg_pred = model(x)
    assert cls_pred.shape == (2, 20)
    assert seg_pred.shape == (2, 21, 320, 320)

    model = Segmentation(mid_ch=32)
    summary(model, input_size=(3, 320, 320))
    x = torch.rand([2, 3, 320, 320])
    cls_pred, seg_pred = model(x)
    assert cls_pred.shape == (2, 20)
    assert seg_pred.shape == (2, 1, 320, 320)

    model = Segmentation(mid_ch=64)
    summary(model, input_size=(3, 320, 320))
    x = torch.rand([2, 3, 320, 320])
    cls_pred, seg_pred = model(x)
    assert cls_pred.shape == (2, 20)
    assert seg_pred.shape == (2, 1, 320, 320)

    ## Test generating pseudo labels
    imgs = torch.rand([1, 1, 2, 3, 320, 320])
    cam, keys = generate_pseudo_label(model, imgs, torch.Tensor([1, 0, 0, 1, 0, 0, 0]), (512, 512))
    assert cam.shape == (512, 512)

    count1 = count_parameters(model)
    count2 = 0
    for p in model.trainable_parameters():
        for pp in p:
            if pp.requires_grad:
                count2 += pp.numel()
    assert count1 == count2
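count_parameters is imported from utils, which is also not part of the gist. For the count1 == count2 assertion above to hold it has to count exactly the trainable parameters, so a plausible sketch (the actual helper may differ) is:

def count_parameters(model):
    # Hypothetical stand-in for utils.count_parameters: number of trainable
    # (requires_grad) parameters in the model.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)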