Skip to content

Instantly share code, notes, and snippets.

@jwuphysics
Last active May 5, 2020 18:47
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jwuphysics/eb388eb3dcee4ccac84f8174a6915bc6 to your computer and use it in GitHub Desktop.
Save jwuphysics/eb388eb3dcee4ccac84f8174a6915bc6 to your computer and use it in GitHub Desktop.
Adding deconvolution (FastDeconv) stems to fastai2 XResNets
"""Combining fastai2 xresnet with deconv stem
Hosted here: https://gist.github.com/jwuphysics/eb388eb3dcee4ccac84f8174a6915bc6
https://github.com/fastai/fastai2/blob/master/fastai2/vision/models/xresnet.py
https://github.com/yechengxi/deconvolution/blob/master/models/resnet.py
"""
from fastai2.vision.all import *
from deconvolution.models.deconv import FastDeconv # from `deconvolution` repo
class ConvLayer_deconv(nn.Sequential):
    """Sequential of a `FastDeconv` layer (`ni` -> `nf` channels), an optional activation, and an optional `xtra` module.

    The signature mirrors fastai2's `ConvLayer` so it can be swapped in, but only the subset
    of options that `FastDeconv` is actually given is supported:

    - `ks`, `stride`, `padding`, `bias` and any extra `kwargs` are passed to `FastDeconv`.
    - `padding=None` means "same"-style padding of `(ks-1)//2`.
    - `act_cls` is instantiated (no arguments) and appended after the conv, unless it is None.
    - `init`/`bias_std` are forwarded to fastai's `init_linear` together with the activation.
    - `xtra`, if truthy, is appended last.

    Raises `NotImplementedError` for `transpose=True` or `ndim != 2`. Neither option is
    forwarded to `FastDeconv`; previously `transpose=True` only zeroed the padding and `ndim`
    was ignored, silently producing a different layer than requested.
    """
    def __init__(self, ni, nf, ks=3, stride=1, padding=None, bias=True, ndim=2,
                 act_cls=defaults.activation, transpose=False, init='auto', xtra=None, bias_std=0.01, **kwargs):
        if transpose:
            raise NotImplementedError("ConvLayer_deconv does not support transpose=True")
        if ndim != 2:
            raise NotImplementedError(f"ConvLayer_deconv only supports ndim=2, got ndim={ndim}")
        if padding is None: padding = (ks-1)//2
        conv = FastDeconv(ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding, **kwargs)
        act = None if act_cls is None else act_cls()
        # fastai's init_linear picks the weight init from the activation when init='auto'.
        init_linear(conv, act, init=init, bias_std=bias_std)
        layers = [conv]
        if act is not None: layers.append(act)
        if xtra: layers.append(xtra)
        super().__init__(*layers)
class XResNet_deconv(nn.Sequential):
    """fastai2-style XResNet whose stem is built from `ConvLayer_deconv` (FastDeconv) layers.

    Layout: stem -> MaxPool2d(3, stride 2, padding 1) -> residual stages
    -> AdaptiveAvgPool2d(1) -> Flatten -> Dropout(`p`) -> Linear(..., `n_out`).
    """
    @delegates(ResBlock)
    def __init__(self, block, expansion, layers, p=0.0, c_in=3, n_out=1000, stem_szs=(32,32,64),
                 widen=1.0, sa=False, act_cls=defaults.activation, **kwargs):
        """Build the network.

        `block`, `expansion`: residual block class and its channel-expansion factor.
        `layers`: number of blocks per stage. Stages 0-3 use base widths 64/128/256/512 and
            any further stages use 256, all scaled by `widen`.
        `c_in`: input channels. `stem_szs`: output width of each stem layer. Any length is
            accepted; only the first stem layer uses stride 2.
        `sa`: if True, `sa=True` is passed only to the last block of stage `len(layers)-4`.
        `act_cls`: activation class used in the stem and passed to every block.
        Remaining `kwargs` are forwarded to every `block`.

        Raises `ValueError` if the last stem width is not divisible by `expansion`.
        """
        store_attr(self, 'block,expansion,act_cls')
        stem_szs = [c_in, *stem_szs]
        # One stem layer per consecutive pair of widths (previously hard-coded to 3 layers,
        # which raised IndexError or silently dropped entries for other lengths).
        stem = [ConvLayer_deconv(stem_szs[i], stem_szs[i+1], stride=2 if i==0 else 1, act_cls=act_cls)
                for i in range(len(stem_szs)-1)]
        # Stage widths are expressed in units of `expansion` (cf. the head's
        # `block_szs[-1]*expansion`), so the first stage's `ni` must be the stem output
        # width divided by `expansion`. It was previously hard-coded as 64//expansion,
        # which mismatched any stem whose last width is not 64.
        if stem_szs[-1] % expansion:
            raise ValueError(f"last stem width ({stem_szs[-1]}) must be divisible by expansion ({expansion})")
        block_szs = [int(o*widen) for o in [64,128,256,512] + [256]*(len(layers)-4)]
        block_szs = [stem_szs[-1]//expansion] + block_szs
        blocks = self._make_blocks(layers, block_szs, sa, **kwargs)
        super().__init__(
            *stem, nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            *blocks,
            nn.AdaptiveAvgPool2d(1), Flatten(), nn.Dropout(p),
            nn.Linear(block_szs[-1]*expansion, n_out),
        )
        init_cnn(self)

    def _make_blocks(self, layers, block_szs, sa, **kwargs):
        "One stage per entry of `layers`; every stage except the first starts with stride 2."
        return [self._make_layer(ni=block_szs[i], nf=block_szs[i+1], blocks=l,
                                 stride=1 if i==0 else 2, sa=sa and i==len(layers)-4, **kwargs)
                for i,l in enumerate(layers)]

    def _make_layer(self, ni, nf, blocks, stride, sa, **kwargs):
        "A stage of `blocks` residual blocks; only the first changes width/stride, only the last gets `sa`."
        return nn.Sequential(
            *[self.block(self.expansion, ni if i==0 else nf, nf, stride=stride if i==0 else 1,
                         sa=sa and i==(blocks-1), act_cls=self.act_cls, **kwargs)
              for i in range(blocks)])
def _xresnet_deconv(pretrained, expansion, layers, **kwargs):
    """Construct an `XResNet_deconv` built from fastai's `ResBlock`.

    `pretrained` is accepted but ignored: no weights are loaded and the model is always
    randomly initialised. The flag is presumably kept so these constructors match the
    `arch(pretrained=...)` calling convention of fastai's `xresnet*` functions; verify
    before relying on it.
    """
    model = XResNet_deconv(ResBlock, expansion=expansion, layers=layers, **kwargs)
    return model
def xresnet18_deconv(pretrained=False, **kwargs):
    "XResNet-18 with a deconv stem: expansion=1, stage depths [2, 2, 2, 2]. `pretrained` is ignored."
    return _xresnet_deconv(pretrained, expansion=1, layers=[2, 2, 2, 2], **kwargs)

def xresnet34_deconv(pretrained=False, **kwargs):
    "XResNet-34 with a deconv stem: expansion=1, stage depths [3, 4, 6, 3]. `pretrained` is ignored."
    return _xresnet_deconv(pretrained, expansion=1, layers=[3, 4, 6, 3], **kwargs)

def xresnet50_deconv(pretrained=False, **kwargs):
    "XResNet-50 with a deconv stem: expansion=4, stage depths [3, 4, 6, 3]. `pretrained` is ignored."
    return _xresnet_deconv(pretrained, expansion=4, layers=[3, 4, 6, 3], **kwargs)

def xresnet101_deconv(pretrained=False, **kwargs):
    "XResNet-101 with a deconv stem: expansion=4, stage depths [3, 4, 23, 3]. `pretrained` is ignored."
    return _xresnet_deconv(pretrained, expansion=4, layers=[3, 4, 23, 3], **kwargs)

def xresnet152_deconv(pretrained=False, **kwargs):
    "XResNet-152 with a deconv stem: expansion=4, stage depths [3, 8, 36, 3]. `pretrained` is ignored."
    return _xresnet_deconv(pretrained, expansion=4, layers=[3, 8, 36, 3], **kwargs)
def xresnet18_deep_deconv(pretrained=False, **kwargs):
    "Deep XResNet-18 (deconv stem): expansion=1, 6 stages [2,2,2,2,1,1]. `pretrained` is ignored."
    return _xresnet_deconv(pretrained, expansion=1, layers=[2,2,2,2,1,1], **kwargs)

def xresnet34_deep_deconv(pretrained=False, **kwargs):
    "Deep XResNet-34 (deconv stem): expansion=1, 6 stages [3,4,6,3,1,1]. `pretrained` is ignored."
    return _xresnet_deconv(pretrained, expansion=1, layers=[3,4,6,3,1,1], **kwargs)

def xresnet50_deep_deconv(pretrained=False, **kwargs):
    "Deep XResNet-50 (deconv stem): expansion=4, 6 stages [3,4,6,3,1,1]. `pretrained` is ignored."
    return _xresnet_deconv(pretrained, expansion=4, layers=[3,4,6,3,1,1], **kwargs)

def xresnet18_deeper_deconv(pretrained=False, **kwargs):
    "Deeper XResNet-18 (deconv stem): expansion=1, 8 stages [2,2,1,1,1,1,1,1]. `pretrained` is ignored."
    return _xresnet_deconv(pretrained, expansion=1, layers=[2,2,1,1,1,1,1,1], **kwargs)

def xresnet34_deeper_deconv(pretrained=False, **kwargs):
    "Deeper XResNet-34 (deconv stem): expansion=1, 8 stages [3,4,6,3,1,1,1,1]. `pretrained` is ignored."
    return _xresnet_deconv(pretrained, expansion=1, layers=[3,4,6,3,1,1,1,1], **kwargs)

def xresnet50_deeper_deconv(pretrained=False, **kwargs):
    "Deeper XResNet-50 (deconv stem): expansion=4, 8 stages [3,4,6,3,1,1,1,1]. `pretrained` is ignored."
    return _xresnet_deconv(pretrained, expansion=4, layers=[3,4,6,3,1,1,1,1], **kwargs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment