Skip to content

Instantly share code, notes, and snippets.

@KeremTurgutlu
Created May 6, 2019 07:17
Show Gist options
  • Save KeremTurgutlu/72dd8d5a5bf6e6272e499d854c479232 to your computer and use it in GitHub Desktop.
Save KeremTurgutlu/72dd8d5a5bf6e6272e499d854c479232 to your computer and use it in GitHub Desktop.
Ventricle Models
from fastai.vision import *
import math
# Public names exported by `from <this module> import *`.
__all__ = [
    'MeshNet',
    'VolumetricUnet',
    'conv_relu_bn_drop',
    'res3dmodel',
    'get_total_params',
    'VolumetricResidualUnet',
    'model_dict',
    'experiment_model_dict',
    'one_by_one_conv',
    'model_split_dict',
]
####################
## GET MODELS ##
####################
# 1 - default unet
def unet_default(**kwargs):
    """Default 3D U-Net (https://arxiv.org/pdf/1606.06650.pdf): out_c=4, 3 levels.

    Extra keyword arguments (e.g. p, norm_type, actn) are forwarded to
    VolumetricUnet, which hands them to every conv block.
    """
    config = dict(in_c=1, out_c=4, n_layers=3, c=1, block_type=conv_relu_bn_drop)
    return VolumetricUnet(**config, **kwargs)
# 2 - unet wider
def unet_wide(**kwargs):
    """3D U-Net with twice the default base width (out_c=8), 3 levels.

    Extra keyword arguments are forwarded to VolumetricUnet's conv blocks.
    """
    config = dict(in_c=1, out_c=8, n_layers=3, c=1, block_type=conv_relu_bn_drop)
    return VolumetricUnet(**config, **kwargs)
# 3 - unet deeper
def unet_deep(**kwargs):
    """3D U-Net with the default base width (out_c=4) but 5 levels.

    Extra keyword arguments are forwarded to VolumetricUnet's conv blocks.
    """
    config = dict(in_c=1, out_c=4, n_layers=5, c=1, block_type=conv_relu_bn_drop)
    return VolumetricUnet(**config, **kwargs)
# 4 - unet wide deep
def unet_wide_deep(**kwargs):
    """3D U-Net that is both wider (out_c=8) and deeper (5 levels).

    Extra keyword arguments are forwarded to VolumetricUnet's conv blocks.
    """
    config = dict(in_c=1, out_c=8, n_layers=5, c=1, block_type=conv_relu_bn_drop)
    return VolumetricUnet(**config, **kwargs)
# 5 - default meshnet - dilated FCN
def meshnet():
    """MeshNet dilated FCN: 8 feature channels, one output channel, no dropout."""
    spec = dict(
        in_c=1,
        out_c=8,
        num_classes=1,
        drop=0,
        dilations=[1, 1, 1, 2, 4, 8, 1, 1],
        kernel_size=3,
    )
    return MeshNet(**spec)
# 6 - default modified 3d unet - lower lr=1e-1
def modified_unet():
    """Modified 3D U-Net with 8 base filters, 1 input channel, 1 class.

    NOTE(review): `Modified3DUNet` is neither defined in this module nor
    obviously provided by the visible imports; calling this raises NameError
    unless it is injected elsewhere — confirm where it is meant to come from.
    """
    return Modified3DUNet(base_n_filter=8, in_channels=1, n_classes=1)
# 6b - wider modified 3d unet
def modified_unet_wide():
    """'Wide' variant of the modified 3D U-Net.

    NOTE(review): this passes the same base_n_filter=8 as modified_unet, so
    the two factories currently build identical models — confirm the intended
    width. `Modified3DUNet` is also not defined in this module.
    """
    return Modified3DUNet(base_n_filter=8, in_channels=1, n_classes=1)
# 7 - 3d residual model
def res3d():
    """Plain 3D residual model; a thin alias for res3dmodel()."""
    model = res3dmodel()
    return model
# residual fused unet
def residual_unet():
    """Residual U-Net summing segmentation maps from three decoder depths,
    base width 8 (https://arxiv.org/pdf/1802.10508.pdf)."""
    block_opts = dict(p=0.2, norm_type='instance', actn='prelu')
    return VolumetricResidualUnet(in_c=8, **block_opts)
# residual fused unet wide
def residual_unet_wide():
    """Wider residual U-Net, base width 12 (https://arxiv.org/pdf/1802.10508.pdf)."""
    block_opts = dict(p=0.2, norm_type='instance', actn='prelu')
    return VolumetricResidualUnet(in_c=12, **block_opts)
# Registry: model name -> factory (every factory here can be called with no arguments).
model_dict = dict(
    unet_default=unet_default,
    unet_wide=unet_wide,
    unet_deep=unet_deep,
    unet_wide_deep=unet_wide_deep,
    meshnet=meshnet,
    modified_unet=modified_unet,
    modified_unet_wide=modified_unet_wide,
    res3d=res3d,
    residual_unet=residual_unet,
    residual_unet_wide=residual_unet_wide,
)
# Experiment id -> model factory. baseline1 and baseline2 build identical
# models; per the trailing notes they differ only in the loss used (bce / dice).
experiment_model_dict = dict(
    baseline1=partial(unet_default, p=0.0, norm_type='batch', actn='relu'),  # bce
    baseline2=partial(unet_default, p=0.0, norm_type='batch', actn='relu'),  # dice
    baseline3=partial(unet_default, p=0.0, norm_type='group', actn='relu'),
    baseline4=partial(unet_default, p=0.0, norm_type='group', actn='prelu'),
    baseline5=partial(unet_default, p=0.3, norm_type='group', actn='prelu'),
    baseline6=meshnet,
    baseline7=partial(unet_wide, p=0.0, norm_type='group', actn='prelu'),
    baseline8=partial(unet_deep, p=0.0, norm_type='group', actn='prelu'),
    baseline9=partial(unet_wide_deep, p=0.0, norm_type='group', actn='prelu'),
    baseline10=residual_unet,
    baseline11=residual_unet_wide,
)
####################
## SPLIT FUNCS ##
####################
def _baseline1_split(m:nn.Module):
    """Layer groups for a VolumetricUnet: (encoder + bottleneck, decoder, 1x1 head)."""
    encoder = nn.ModuleList([m.downblocks, m.middle])
    decoder = m.upblocks
    head = m.conv_final
    return encoder, decoder, head
def _baseline6_split(m:nn.Module):
    """Split a MeshNet's `layers` Sequential into three layer groups.

    Groups are layers[:4], layers[4:-1] and layers[-1:] (the final 1x1x1
    classifier conv). With the default 8-layer MeshNet this is [:4], [4:7],
    [7:]. Every layer lands in exactly one group. The previous split
    ([:4], [4:6], [7:]) skipped layers[6], so that layer's parameters were in
    no group — presumably leaving them out of the optimizer built from the
    layer groups.
    """
    return (m.layers[:4], m.layers[4:-1], m.layers[-1:])
# def _baseline8_split(m:nn.Module): return (nn.ModuleList([m.downblocks, m.middle]),
# m.upblocks[:3],
# m.upblocks[3:],
# m.conv_final)
def _baseline8_split(m:nn.Module):
    """Layer groups for a 5-level VolumetricUnet.

    Groups: (downblocks + middle), upblocks[:3], (upblocks[3:] + conv_final).
    Note upblocks[0] is the shallowest decoder stage (forward walks upblocks
    in reverse), so the second group holds the three shallowest stages and
    the last group the two deepest stages plus the 1x1 head.
    """
    encoder = nn.ModuleList([m.downblocks, m.middle])
    upblocks_head_side = m.upblocks[:3]
    upblocks_tail_and_final = nn.ModuleList([m.upblocks[3:], m.conv_final])
    return encoder, upblocks_head_side, upblocks_tail_and_final
def _baseline10_split(m:nn.Module):
    """Layer groups for a VolumetricResidualUnet.

    Groups: the four down blocks; (middle, upblock1); the remaining up blocks
    together with all three segmentation heads.
    """
    encoder = nn.ModuleList([m.down1, m.down2, m.down3, m.down4])
    bottleneck = nn.ModuleList([m.middle, m.upblock1])
    decoder = nn.ModuleList([m.upblock2, m.upblock3, m.seg2, m.seg3, m.seg_final])
    return encoder, bottleneck, decoder
# Experiment id -> function splitting the built model into fastai layer groups.
# Ids must match experiment_model_dict: all plain U-Nets share
# _baseline1_split, the 5-level U-Nets use _baseline8_split, MeshNet uses
# _baseline6_split and the residual U-Nets use _baseline10_split.
model_split_dict = dict(
    baseline1=_baseline1_split,
    baseline2=_baseline1_split,
    baseline3=_baseline1_split,
    baseline4=_baseline1_split,
    baseline5=_baseline1_split,
    baseline6=_baseline6_split,
    baseline7=_baseline1_split,
    baseline8=_baseline8_split,
    baseline9=_baseline8_split,
    baseline10=_baseline10_split,
    baseline11=_baseline10_split,
)
####################
## LAYERS ##
####################
class RunningBatchNorm(nn.Module):
    """Batch norm that normalizes with exponentially-averaged statistics.

    Every training-mode forward call folds the batch into running sums
    (`sums`, `sqrs`, `count`) and recomputes a per-channel scale (`factor`)
    and shift (`offset`). Eval mode reuses the last factor/offset unchanged.
    Before the first training step both are 0, so the output is all zeros.

    Accepts input of shape (N, C, *spatial) with any number of spatial dims
    (e.g. 4-D Conv2d output or 5-D Conv3d output). Parameter and buffer
    shapes are the same as before, so existing state dicts still match.

    Args:
        nf: number of channels C.
        mom: base momentum for the running statistics.
        eps: added to the variance before the square root.
    """
    def __init__(self, nf, mom=0.1, eps=1e-5):
        super().__init__()
        self.mom, self.eps = mom, eps
        self.mults = nn.Parameter(torch.ones (nf,1,1))
        self.adds = nn.Parameter(torch.zeros(nf,1,1))
        self.register_buffer('sums', torch.zeros(1,nf,1,1))
        self.register_buffer('sqrs', torch.zeros(1,nf,1,1))
        self.register_buffer('count', torch.tensor(0.))
        self.register_buffer('factor', torch.tensor(0.))
        self.register_buffer('offset', torch.tensor(0.))
        # Number of samples seen so far (plain int, not saved in the state dict).
        self.batch = 0

    def update_stats(self, x):
        """Fold batch `x` into the running statistics and refresh factor/offset."""
        bs, nc, *_ = x.shape
        # Drop autograd history carried over from the previous step.
        self.sums.detach_()
        self.sqrs.detach_()
        # Reduce over every dim except channels, then reshape to the buffers'
        # (1, C, 1, 1) layout. The former hard-coded dims=(0,2,3) only fit
        # 4-D input and broke on 5-D (Conv3d) activations.
        dims = (0,) + tuple(range(2, x.dim()))
        s = x.sum(dims).view(1, nc, 1, 1)
        ss = (x*x).sum(dims).view(1, nc, 1, 1)
        c = s.new_tensor(x.numel()/nc)
        # Momentum grows with batch size. bs-1 is floored at 1 so a
        # single-sample batch no longer raises ZeroDivisionError.
        mom1 = s.new_tensor(1 - (1-self.mom)/math.sqrt(max(bs-1, 1)))
        self.sums .lerp_(s , mom1)
        self.sqrs .lerp_(ss, mom1)
        self.count.lerp_(c , mom1)
        self.batch += bs
        means = self.sums/self.count
        varns = (self.sqrs/self.count).sub_(means*means)
        # Early on the variance estimate is noisy; keep it away from zero.
        if bool(self.batch < 20): varns.clamp_min_(0.01)
        self.factor = self.mults / (varns+self.eps).sqrt()
        self.offset = self.adds - means*self.factor

    def forward(self, x):
        if self.training: self.update_stats(x)
        # Broadcast the (1, C, 1, 1) factor/offset over however many spatial dims x has.
        shape = (1, -1) + (1,) * (x.dim() - 2)
        return x*self.factor.view(shape) + self.offset.view(shape)
def get_total_params(model, *, verbose=True):
    """Return the total number of scalar elements in `model.parameters()`.

    Args:
        model: any object with a `.parameters()` iterator of tensors.
        verbose: if true (the default, as before), print each parameter's
            shape and element count, then the total.

    Uses Tensor.numel() rather than np.product, which was deprecated in
    NumPy 1.25 and removed in NumPy 2.0.
    """
    tot_params = 0
    for p in model.parameters():
        n = p.numel()
        tot_params += n
        if verbose: print(p.shape, n)
    if verbose: print('total:', tot_params)
    return tot_params
def maxpool3D():
    """Return a 3-D max-pool that halves each spatial dim (kernel 2, stride 2)."""
    return nn.MaxPool3d(kernel_size=2, stride=2)
def one_by_one_conv(in_channel, out_channel):
    """Return a biased 1x1x1 Conv3d mapping in_channel -> out_channel channels."""
    return nn.Conv3d(in_channel, out_channel, kernel_size=1)
def conv_relu_bn_drop(in_channel, out_channel, dilation=1, p=0.5, norm_type='batch', actn='relu', init_func=None, stride=1, kernel=3):
    """Build a Conv3d -> activation -> normalization -> Dropout3d block.

    Padding is dilation * (kernel - 1) // 2 per spatial dim, which is "same"
    padding for odd kernels at stride 1. For the default kernel=3 this equals
    `dilation`, exactly as before. Previously every kernel size got
    padding=dilation, which does not preserve the spatial size when kernel != 3.

    Args:
        in_channel, out_channel: conv input / output channels.
        dilation: conv dilation (int or 3-sequence).
        p: Dropout3d probability.
        norm_type: 'batch', 'instance', 'group' (GroupNorm with 2 groups) or
            'running' (RunningBatchNorm).
        actn: 'relu' (in place) or 'prelu'.
        init_func: if truthy, passed with the conv to `init_default` (not
            defined in this file; presumably fastai's helper from the star import).
        stride: conv stride.
        kernel: conv kernel size (int or 3-sequence).

    Raises:
        ValueError: unknown norm_type or actn (previously an UnboundLocalError).
    """
    norms = {
        'batch': lambda: nn.BatchNorm3d(out_channel),
        'instance': lambda: nn.InstanceNorm3d(out_channel),
        'group': lambda: nn.GroupNorm(2, out_channel),
        'running': lambda: RunningBatchNorm(out_channel),
    }
    actns = {
        'relu': lambda: nn.ReLU(inplace=True),
        'prelu': lambda: nn.PReLU(),
    }
    if norm_type not in norms:
        raise ValueError('unknown norm_type {!r}; expected one of {}'.format(norm_type, sorted(norms)))
    if actn not in actns:
        raise ValueError('unknown actn {!r}; expected one of {}'.format(actn, sorted(actns)))
    norm = norms[norm_type]()
    actn_fn = actns[actn]()
    # Per-dim "same" padding: d * (k - 1) // 2 (equals d for k == 3).
    ks = tuple(kernel) if isinstance(kernel, (tuple, list)) else (kernel,) * 3
    ds = tuple(dilation) if isinstance(dilation, (tuple, list)) else (dilation,) * 3
    padding = tuple(d * (k - 1) // 2 for k, d in zip(ks, ds))
    conv = nn.Conv3d(in_channel, out_channel, kernel_size=kernel, stride=stride, padding=padding, dilation=dilation, bias=True)
    if init_func: init_default(conv, init_func)
    drop = nn.Dropout3d(p)
    return nn.Sequential(conv, actn_fn, norm, drop)
####################
## MESHNET ##
####################
class MeshNet(nn.Module):
    """MeshNet: dilated fully-convolutional 3-D segmentation network.

    Reference: https://arxiv.org/pdf/1612.00940.pdf

    `self.layers` holds len(dilations) modules:
      * layers[0]: conv block in_c -> out_c with dilation dilations[0];
      * one out_c -> out_c conv block per entry of dilations[1:-1];
      * a final 1x1x1 conv out_c -> num_classes (dilations[-1] is unused).
    Conv blocks use group norm, PReLU, Dropout3d(drop) and
    init_func=nn.init.kaiming_normal_.

    Args:
        in_c: input channels.
        out_c: channels of every hidden layer.
        num_classes: output channels.
        drop: dropout probability of the conv blocks.
        dilations: dilation schedule (see layout above).
        kernel_size: NOTE(review): accepted but not forwarded to
            conv_relu_bn_drop, which uses its default 3x3x3 kernel.
    """
    def __init__(self, in_c=1, out_c=24, num_classes=1, drop=0,
                 dilations=(1,1,1,2,4,8,1,1), kernel_size=3):
        super(MeshNet, self).__init__()
        block_opts = dict(p=drop, norm_type='group', actn='prelu',
                          init_func=nn.init.kaiming_normal_)
        layers = [conv_relu_bn_drop(in_c, out_c, dilation=dilations[0], **block_opts)]
        # Hidden layers take the interior of the schedule. The previous loop
        # used dilations[:n] with n = len(dilations[1:-1]), shifting it by one
        # (defaults gave conv dilations 1,1,1,1,2,4,8 instead of 1,1,1,2,4,8,1).
        # Layer count and parameter shapes are unchanged.
        layers += [conv_relu_bn_drop(out_c, out_c, dilation=d, **block_opts)
                   for d in dilations[1:-1]]
        layers.append(one_by_one_conv(out_c, num_classes))
        self.layers = nn.Sequential(*layers)
    def forward(self, x): return self.layers(x)
####################
## UNET ##
####################
class UnetBegin(nn.Module):
    """First encoder stage of VolumetricUnet: conv -> conv -> 2x max-pool.

    Channels go in_c -> out_c -> 2 * out_c. forward returns
    (pooled, pre_pool); VolumetricUnet keeps pre_pool as a skip connection.
    """
    def __init__(self, block_type, in_c, out_c, **kwargs):
        super(UnetBegin, self).__init__()
        self.conv_block1 = block_type(in_c, out_c, **kwargs)
        self.conv_block2 = block_type(out_c, out_c*2, **kwargs)
        self.pool = maxpool3D()
    def forward(self, x):
        features = self.conv_block1(x)
        features = self.conv_block2(features)
        pooled = self.pool(features)
        return pooled, features
class UnetEnd(nn.Module):
    """Last (shallowest) decoder stage of VolumetricUnet: conv -> conv, no resampling.

    Takes 3 * in_c input channels and returns in_c channels. In VolumetricUnet
    the input is an in_c-channel skip concatenated with the 2 * in_c-channel
    output of the next-deeper stage.

    Channel counts are computed directly from in_c. The previous
    2**int(log2(in_c)) arithmetic matched only for power-of-two in_c and
    otherwise built layers that did not fit the concatenated input.
    """
    def __init__(self, block_type, in_c=64, **kwargs):
        super(UnetEnd, self).__init__()
        self.conv_block1 = block_type(3 * in_c, in_c, **kwargs)
        self.conv_block2 = block_type(in_c, in_c, **kwargs)
    def forward(self, x):
        return self.conv_block2(self.conv_block1(x))
class UnetDownBlock(nn.Module):
    """Encoder stage of VolumetricUnet: conv -> conv -> 2x max-pool.

    Channels go in_c -> in_c -> 2 * in_c. forward returns (pooled, skip),
    where skip is the pre-pool feature map.
    """
    def __init__(self, block_type, in_c=64, **kwargs):
        super(UnetDownBlock, self).__init__()
        self.conv_block1 = block_type(in_c, in_c, **kwargs)
        self.conv_block2 = block_type(in_c, in_c*2, **kwargs)
        self.pool = maxpool3D()
    def forward(self, x):
        skip = self.conv_block1(x)
        skip = self.conv_block2(skip)
        return self.pool(skip), skip
class UnetUpBlock(nn.Module):
    """Decoder stage of VolumetricUnet: conv -> conv -> trilinear x2 upsample.

    Takes 3 * in_c input channels and returns in_c channels at twice the
    spatial size. In VolumetricUnet the input is an in_c-channel skip
    concatenated with the 2 * in_c-channel output of the deeper stage.

    Channel counts are computed directly from in_c. The previous
    2**int(log2(in_c)) arithmetic matched only for power-of-two in_c and
    otherwise built layers that did not fit the concatenated input.
    """
    def __init__(self, block_type, in_c=64, **kwargs):
        super(UnetUpBlock, self).__init__()
        self.conv_block1 = block_type(3 * in_c, in_c, **kwargs)
        self.conv_block2 = block_type(in_c, in_c, **kwargs)
    def forward(self, x):
        x = self.conv_block2(self.conv_block1(x))
        return F.interpolate(x, scale_factor=2, mode='trilinear', align_corners=True)
class VolumetricUnet(nn.Module):
    def __init__(self, in_c=1, out_c=32, n_layers=3, c=1, block_type=conv_relu_bn_drop, **kwargs):
        """Build a 3-D U-Net.

        Args:
            in_c: input channels.
            out_c: width of the first conv block; encoder stage k outputs
                out_c * 2**(k+1) channels.
            n_layers: number of encoder (pooling) stages.
            c: output channels of the final 1x1x1 conv.
            block_type: factory `block_type(in_ch, out_ch, **kwargs)` used for
                every conv block.
            **kwargs: forwarded to every block_type call.
        """
        super(VolumetricUnet, self).__init__()
        # Encoder: UnetBegin, then n_layers - 1 UnetDownBlocks of doubling width.
        encoder = [UnetBegin(block_type, in_c, out_c, **kwargs)]
        for depth in range(1, n_layers):
            encoder.append(UnetDownBlock(block_type, out_c * 2**depth, **kwargs))
        self.downblocks = nn.ModuleList(encoder)
        # Bottleneck at the coarsest resolution; doubles the channel count.
        width = out_c * 2**n_layers
        self.middle = nn.Sequential(block_type(width, width, **kwargs),
                                    block_type(width, width * 2, **kwargs))
        # Decoder, stored shallowest-first (forward iterates it in reverse).
        decoder = [UnetEnd(block_type, out_c * 2, **kwargs)]
        decoder += [UnetUpBlock(block_type, out_c * 2**depth, **kwargs)
                    for depth in range(2, n_layers + 1)]
        self.upblocks = nn.ModuleList(decoder)
        # 1x1x1 projection to the c output channels.
        self.conv_final = one_by_one_conv(out_c * 2, c)
    def forward(self, x):
        # Assumes spatial dims are divisible by 2**n_layers so upsampled
        # tensors line up with their skip connections — confirm for new data.
        skips = []
        for down in self.downblocks:
            x, skip = down(x)
            skips.append(skip)
        x = F.interpolate(self.middle(x), scale_factor=2, mode='trilinear', align_corners=True)
        # Deepest decoder stage pairs with the deepest skip, and so on upward.
        for up, skip in zip(reversed(self.upblocks), reversed(skips)):
            x = up(torch.cat([skip, x], dim=1))
        return self.conv_final(x)
#########################
## RES-3D ##
#########################
def conv3d(ni:int, nf:int, ks:int=3, stride:int=1, pad:int=1, norm='batch'):
    """Conv3d -> ReLU (in place) [-> BatchNorm3d] block used by the res3d model.

    norm='batch' drops the conv bias and appends BatchNorm3d. 'spectral' and
    'weight' wrap the conv in spectral_norm / weight_norm. Any other value
    (e.g. None) leaves a plain biased conv followed by ReLU. The conv goes
    through `init_default` (not defined in this file; presumably fastai's
    initializer from the star import).
    """
    with_bn = norm == 'batch'
    conv = init_default(nn.Conv3d(ni, nf, ks, stride, pad, bias=not with_bn))
    if norm == 'spectral':
        conv = spectral_norm(conv)
    elif norm == 'weight':
        conv = weight_norm(conv)
    # In-place ReLU keeps peak memory down on large 3-D activations.
    layers = [conv, nn.ReLU(inplace=True)]
    if with_bn:
        layers.append(nn.BatchNorm3d(nf))
    return nn.Sequential(*layers)
def res3d_block(ni, nf, ks=3, norm='batch', dense=False):
    """3-D residual block with `nf` features: two conv3d units then a MergeLayer.

    Padding ks // 2 keeps the spatial size for odd ks. `dense` is passed to
    MergeLayer (a fastai helper; presumably it concatenates instead of adding
    the input when true — confirm).
    """
    pad = ks // 2
    first = conv3d(ni, nf, ks, pad=pad, norm=norm)
    second = conv3d(nf, nf, ks, pad=pad, norm=norm)
    return SequentialEx(first, second, MergeLayer(dense))
def res3dmodel():
    """Small 3-D residual network: a 7x7x7 dense stem, four 16-channel
    residual blocks and a 1x1x1 head to one output channel.

    The stem maps 1 -> 15 channels with dense=True; the following blocks
    expect 16, which presumably comes from the dense merge concatenating the
    1-channel input (fastai MergeLayer semantics — confirm).

    NOTE(review): the head is conv3d(..., norm=None), which still appends a
    ReLU, so outputs are >= 0, unlike the raw 1x1 conv heads of the other
    models here. Confirm that this matches the loss used.
    """
    norm = 'batch'
    stem = res3d_block(1, 15, 7, norm=norm, dense=True)
    body = [res3d_block(16, 16, norm=norm) for _ in range(4)]
    head = conv3d(16, 1, ks=1, pad=0, norm=None)
    return nn.Sequential(stem, *body, head)
##################################
## Volumetric Residual Unet ##
##################################
def norm_act_conv_drop(in_channel, out_channel, dilation=1, p=0.5, norm_type='instance', actn='prelu', stride=1):
    """Pre-activation unit: norm -> activation -> 3x3x3 Conv3d -> Dropout3d.

    Because the norm runs before the conv, it is sized for `in_channel`. The
    original sized it for `out_channel`, which only worked because its
    caller (PreActBlock) always uses in_channel == out_channel.

    Args:
        in_channel, out_channel: conv input / output channels.
        dilation: conv dilation. Padding equals dilation, which keeps the
            spatial size for the 3x3x3 kernel at stride 1.
        p: Dropout3d probability.
        norm_type: 'batch', 'instance' or 'group' (GroupNorm with 2 groups).
        actn: 'relu' (in place), 'lrelu' (LeakyReLU, slope 0.1) or 'prelu'.
        stride: conv stride.

    Raises:
        ValueError: unknown norm_type or actn (previously an UnboundLocalError).
    """
    norms = {
        'batch': lambda: nn.BatchNorm3d(in_channel),
        'instance': lambda: nn.InstanceNorm3d(in_channel),
        'group': lambda: nn.GroupNorm(2, in_channel),
    }
    actns = {
        'relu': lambda: nn.ReLU(inplace=True),
        'lrelu': lambda: nn.LeakyReLU(0.1),
        'prelu': lambda: nn.PReLU(),
    }
    if norm_type not in norms:
        raise ValueError('unknown norm_type {!r}; expected one of {}'.format(norm_type, sorted(norms)))
    if actn not in actns:
        raise ValueError('unknown actn {!r}; expected one of {}'.format(actn, sorted(actns)))
    norm = norms[norm_type]()
    actn_fn = actns[actn]()
    conv = nn.Conv3d(in_channel, out_channel, kernel_size=3, stride=stride, padding=dilation, dilation=dilation, bias=True)
    drop = nn.Dropout3d(p)
    return nn.Sequential(norm, actn_fn, conv, drop)
class PreActBlock(nn.Module):
    """Residual block of two channel-preserving pre-activation units: x + f(x)."""
    def __init__(self, in_c, **kwargs):
        super(PreActBlock, self).__init__()
        self.c1 = norm_act_conv_drop(in_c, in_c, **kwargs)
        self.c2 = norm_act_conv_drop(in_c, in_c, **kwargs)
    def forward(self, x):
        residual = self.c1(x)
        residual = self.c2(residual)
        return x + residual
class UpsampleBlock(nn.Module):
    """Trilinear x2 upsample followed by a conv block that halves the channels."""
    def __init__(self, in_c, **kwargs):
        super(UpsampleBlock, self).__init__()
        self.c = conv_relu_bn_drop(in_c, in_c//2, **kwargs)
    def forward(self, x):
        upsampled = F.interpolate(x, scale_factor=2, mode='trilinear', align_corners=True)
        return self.c(upsampled)
class LocalizationBlock(nn.Module):
    """Two conv blocks: in_c -> in_c, then in_c -> in_c // 2.

    NOTE(review): the positional `1` given to the second conv_relu_bn_drop is
    its `dilation` argument, not a kernel size, so both convs use the default
    3x3x3 kernel. If a 1x1x1 conv was intended (as in the localization module
    of arXiv:1802.10508), this needs changing — confirm.
    """
    def __init__(self, in_c, **kwargs):
        super(LocalizationBlock, self).__init__()
        self.c1 = conv_relu_bn_drop(in_c, in_c, **kwargs)
        self.c2 = conv_relu_bn_drop(in_c, in_c//2, 1, **kwargs)
    def forward(self, x):
        out = self.c1(x)
        return self.c2(out)
class UpBlock(nn.Module):
    """Localization (in_c -> in_c // 2) then upsampling (-> in_c // 4 at 2x size).

    forward returns (upsampled, localized); the residual U-Net feeds
    `localized` to a segmentation head.
    """
    def __init__(self, in_c, **kwargs):
        super(UpBlock, self).__init__()
        self.c1 = LocalizationBlock(in_c, **kwargs)
        self.c2 = UpsampleBlock(in_c//2, **kwargs)
    def forward(self, x):
        localized = self.c1(x)
        upsampled = self.c2(localized)
        return upsampled, localized
class SegLayer(nn.Module):
    """Segmentation head: conv block (dropout 0.5, instance norm, PReLU), then a
    1x1x1 conv to a single output channel.

    NOTE(review): **kwargs is accepted but ignored; the block settings above
    are fixed regardless of what the caller passes.
    """
    def __init__(self, in_c, **kwargs):
        super(SegLayer, self).__init__()
        self.c1 = conv_relu_bn_drop(in_c, in_c, 1, p=0.5, norm_type='instance', actn='prelu')
        self.c2 = one_by_one_conv(in_c, 1)
    def forward(self, x):
        hidden = self.c1(x)
        return self.c2(hidden)
class DownBlock(nn.Module):
    """Encoder stage: conv block in_c1 -> in_c2 (stride 1 if `first`, else
    stride 2 to halve the resolution), then a channel-preserving PreActBlock."""
    def __init__(self, in_c1, in_c2, first=False, **kwargs):
        super(DownBlock, self).__init__()
        stride = 1 if first else 2
        self.down = conv_relu_bn_drop(in_c1, in_c2, stride=stride, **kwargs)
        self.preact_res = PreActBlock(in_c2, **kwargs)
    def forward(self, x):
        out = self.down(x)
        return self.preact_res(out)
class VolumetricResidualUnet(nn.Module):
    """Residual 3-D U-Net after https://arxiv.org/pdf/1802.10508.pdf.

    Takes 1-channel input (down1 starts from 1 channel) and returns a
    1-channel map: the sum of three segmentation heads taken at 1/4, 1/2 and
    full resolution, upsampled and added together. Spatial dims are presumably
    expected to be divisible by 16 (four stride-2 stages) — confirm.

    Args:
        in_c: base width; stage k uses in_c * 2**k channels.
        **kwargs: forwarded to the conv / residual / up blocks (not to SegLayer).
    """
    def __init__(self, in_c=8, **kwargs):
        super(VolumetricResidualUnet, self).__init__()
        w1, w2, w3, w4, w5 = (in_c * 2**k for k in range(5))
        # Encoder: full resolution, then three stride-2 stages.
        self.down1 = DownBlock(1, w1, first=True, **kwargs)
        self.down2 = DownBlock(w1, w2, **kwargs)
        self.down3 = DownBlock(w2, w3, **kwargs)
        self.down4 = DownBlock(w3, w4, **kwargs)
        # Bottleneck: stride-2 conv to w5, residual block, upsample back to w4 channels.
        self.middle = nn.Sequential(
            conv_relu_bn_drop(w4, w5, 1, stride=2, **kwargs),
            PreActBlock(w5, **kwargs),
            UpsampleBlock(w5, **kwargs),
        )
        # Decoder stages.
        self.upblock1 = UpBlock(w5, **kwargs)
        self.upblock2 = UpBlock(w4, **kwargs)
        self.upblock3 = UpBlock(w3, **kwargs)
        # Segmentation heads (deep supervision).
        self.seg2 = SegLayer(w3)
        self.seg3 = SegLayer(w2)
        self.seg_final = SegLayer(w2)
    def upsample(self, x):
        """Trilinear x2 upsampling (align_corners=True)."""
        return F.interpolate(x, scale_factor=2, mode='trilinear', align_corners=True)
    def forward(self, x):
        skip1 = self.down1(x)
        skip2 = self.down2(skip1)
        skip3 = self.down3(skip2)
        skip4 = self.down4(skip3)
        # Each decoder stage consumes [skip, decoder features] along channels.
        dec1, _ = self.upblock1(torch.cat([skip4, self.middle(skip4)], dim=1))
        dec2, seg_feat2 = self.upblock2(torch.cat([skip3, dec1], dim=1))
        dec3, seg_feat3 = self.upblock3(torch.cat([skip2, dec2], dim=1))
        top = torch.cat([skip1, dec3], dim=1)
        # Heads run in the order seg2, seg3, seg_final; coarse maps are
        # upsampled and summed into the finer ones.
        coarse = self.upsample(self.seg2(seg_feat2)) + self.seg3(seg_feat3)
        return self.upsample(coarse) + self.seg_final(top)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment