Last active
May 5, 2020 18:47
-
-
Save jwuphysics/eb388eb3dcee4ccac84f8174a6915bc6 to your computer and use it in GitHub Desktop.
adding deconvolution stems to fastai2 xresnets
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Combining fastai2 xresnet with deconv stem | |
Hosted here: https://gist.github.com/jwuphysics/eb388eb3dcee4ccac84f8174a6915bc6
fastai2 xresnet source: https://github.com/fastai/fastai2/blob/master/fastai2/vision/models/xresnet.py
deconvolution resnet source: https://github.com/yechengxi/deconvolution/blob/master/models/resnet.py
""" | |
from fastai2.vision.all import * | |
from deconvolution.models.deconv import FastDeconv # from `deconvolution` repo | |
class ConvLayer_deconv(nn.Sequential):
    "Create a sequence of a `FastDeconv` layer (`ni` to `nf`), an activation (unless `act_cls` is None) and an optional `xtra` module."
    def __init__(self, ni, nf, ks=3, stride=1, padding=None, bias=True, ndim=2,
                 act_cls=defaults.activation, transpose=False, init='auto', xtra=None, bias_std=0.01, **kwargs):
        # NOTE(review): `ndim` is never read in this block (presumably kept for parity with
        # fastai's `ConvLayer` signature), and `transpose` only selects the default padding.
        if padding is None:
            padding = 0 if transpose else (ks - 1) // 2
        # Any extra keyword arguments go straight to the `FastDeconv` constructor.
        deconv = FastDeconv(ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding, **kwargs)
        activation = act_cls() if act_cls is not None else None
        # The activation is handed to `init_linear` together with the layer it follows.
        init_linear(deconv, activation, init=init, bias_std=bias_std)
        modules = [deconv]
        if activation:
            modules.append(activation)
        if xtra:
            modules.append(xtra)
        super().__init__(*modules)
class XResNet_deconv(nn.Sequential):
    "Sequential net: three `ConvLayer_deconv` stem layers, a max-pool, one stage of `block`s per entry of `layers`, then pool/flatten/dropout/linear."
    @delegates(ResBlock)
    def __init__(self, block, expansion, layers, p=0.0, c_in=3, n_out=1000, stem_szs=(32,32,64),
                 widen=1.0, sa=False, act_cls=defaults.activation, **kwargs):
        # `_make_layer` reads these three attributes back, so they are stored before any stage is built.
        store_attr(self, 'block,expansion,act_cls')

        # Stem: three layers chaining c_in -> stem_szs[0] -> stem_szs[1] -> stem_szs[2];
        # only the first one uses stride 2.
        stem_chans = [c_in, *stem_szs]
        stem = []
        for idx in range(3):
            stem.append(ConvLayer_deconv(stem_chans[idx], stem_chans[idx+1],
                                         stride=2 if idx == 0 else 1, act_cls=act_cls))

        # Stage widths: the four base widths, plus one 256-wide stage per entry of `layers`
        # beyond the fourth, each scaled by `widen`. The leading entry is the input width
        # of the first stage.
        stage_widths = [64, 128, 256, 512] + [256] * (len(layers) - 4)
        block_szs = [64 // expansion] + [int(w * widen) for w in stage_widths]
        blocks = self._make_blocks(layers, block_szs, sa, **kwargs)

        pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        head = [nn.AdaptiveAvgPool2d(1), Flatten(), nn.Dropout(p),
                nn.Linear(block_szs[-1]*expansion, n_out)]
        super().__init__(*stem, pool, *blocks, *head)
        init_cnn(self)

    def _make_blocks(self, layers, block_szs, sa, **kwargs):
        "Return one stage per entry of `layers`; the first stage has stride 1, the rest stride 2."
        # `sa` is only passed on to the stage at index `len(layers)-4`.
        sa_stage = len(layers) - 4
        stages = []
        for idx, n_blocks in enumerate(layers):
            stages.append(self._make_layer(ni=block_szs[idx], nf=block_szs[idx+1], blocks=n_blocks,
                                           stride=1 if idx == 0 else 2,
                                           sa=sa and idx == sa_stage, **kwargs))
        return stages

    def _make_layer(self, ni, nf, blocks, stride, sa, **kwargs):
        "Stack `blocks` instances of `self.block`; only the first takes `ni` and `stride`, only the last may receive `sa`."
        units = []
        for idx in range(blocks):
            is_first = idx == 0
            units.append(self.block(self.expansion, ni if is_first else nf, nf,
                                    stride=stride if is_first else 1,
                                    sa=sa and idx == blocks - 1, act_cls=self.act_cls, **kwargs))
        return nn.Sequential(*units)
def _xresnet_deconv(pretrained, expansion, layers, **kwargs):
    """Build an `XResNet_deconv` out of `ResBlock`s.

    `expansion` and `layers` (number of blocks per stage) are passed positionally to
    `XResNet_deconv`; remaining `kwargs` are forwarded unchanged.

    `pretrained` is accepted so that every public constructor shares one signature, but
    this function never loads weights. A truthy value therefore emits a `UserWarning`
    instead of being silently ignored, and the model is still returned.
    """
    if pretrained:
        import warnings  # local import so the file's import block does not need to change
        # stacklevel=3 attributes the warning to the caller of the public `xresnet*_deconv` wrapper.
        warnings.warn("`pretrained=True` is not supported for deconv xresnets; "
                      "the model is returned without loading any weights.", stacklevel=3)
    return XResNet_deconv(ResBlock, expansion, layers, **kwargs)
def xresnet18_deconv(pretrained=False, **kwargs):
    "Call `_xresnet_deconv` with expansion 1 and stage depths [2, 2, 2, 2]; `pretrained` and `kwargs` are forwarded."
    stage_depths = [2, 2, 2, 2]
    return _xresnet_deconv(pretrained, 1, stage_depths, **kwargs)
def xresnet34_deconv(pretrained=False, **kwargs):
    "Call `_xresnet_deconv` with expansion 1 and stage depths [3, 4, 6, 3]; `pretrained` and `kwargs` are forwarded."
    stage_depths = [3, 4, 6, 3]
    return _xresnet_deconv(pretrained, 1, stage_depths, **kwargs)
def xresnet50_deconv(pretrained=False, **kwargs):
    "Call `_xresnet_deconv` with expansion 4 and stage depths [3, 4, 6, 3]; `pretrained` and `kwargs` are forwarded."
    stage_depths = [3, 4, 6, 3]
    return _xresnet_deconv(pretrained, 4, stage_depths, **kwargs)
def xresnet101_deconv(pretrained=False, **kwargs):
    "Call `_xresnet_deconv` with expansion 4 and stage depths [3, 4, 23, 3]; `pretrained` and `kwargs` are forwarded."
    stage_depths = [3, 4, 23, 3]
    return _xresnet_deconv(pretrained, 4, stage_depths, **kwargs)
def xresnet152_deconv(pretrained=False, **kwargs):
    "Call `_xresnet_deconv` with expansion 4 and stage depths [3, 8, 36, 3]; `pretrained` and `kwargs` are forwarded."
    stage_depths = [3, 8, 36, 3]
    return _xresnet_deconv(pretrained, 4, stage_depths, **kwargs)
def xresnet18_deep_deconv(pretrained=False, **kwargs):
    "Call `_xresnet_deconv` with expansion 1 and six stages of depths [2, 2, 2, 2, 1, 1]; `pretrained` and `kwargs` are forwarded."
    stage_depths = [2, 2, 2, 2, 1, 1]
    return _xresnet_deconv(pretrained, 1, stage_depths, **kwargs)
def xresnet34_deep_deconv(pretrained=False, **kwargs):
    "Call `_xresnet_deconv` with expansion 1 and six stages of depths [3, 4, 6, 3, 1, 1]; `pretrained` and `kwargs` are forwarded."
    stage_depths = [3, 4, 6, 3, 1, 1]
    return _xresnet_deconv(pretrained, 1, stage_depths, **kwargs)
def xresnet50_deep_deconv(pretrained=False, **kwargs):
    "Call `_xresnet_deconv` with expansion 4 and six stages of depths [3, 4, 6, 3, 1, 1]; `pretrained` and `kwargs` are forwarded."
    stage_depths = [3, 4, 6, 3, 1, 1]
    return _xresnet_deconv(pretrained, 4, stage_depths, **kwargs)
def xresnet18_deeper_deconv(pretrained=False, **kwargs):
    "Call `_xresnet_deconv` with expansion 1 and eight stages of depths [2, 2, 1, 1, 1, 1, 1, 1]; `pretrained` and `kwargs` are forwarded."
    # NOTE(review): unlike `xresnet18_deep_deconv` ([2,2,2,2,1,1]), stages three and four
    # have a single block here — confirm this is intended.
    stage_depths = [2, 2, 1, 1, 1, 1, 1, 1]
    return _xresnet_deconv(pretrained, 1, stage_depths, **kwargs)
def xresnet34_deeper_deconv(pretrained=False, **kwargs):
    "Call `_xresnet_deconv` with expansion 1 and eight stages of depths [3, 4, 6, 3, 1, 1, 1, 1]; `pretrained` and `kwargs` are forwarded."
    stage_depths = [3, 4, 6, 3, 1, 1, 1, 1]
    return _xresnet_deconv(pretrained, 1, stage_depths, **kwargs)
def xresnet50_deeper_deconv(pretrained=False, **kwargs):
    "Call `_xresnet_deconv` with expansion 4 and eight stages of depths [3, 4, 6, 3, 1, 1, 1, 1]; `pretrained` and `kwargs` are forwarded."
    stage_depths = [3, 4, 6, 3, 1, 1, 1, 1]
    return _xresnet_deconv(pretrained, 4, stage_depths, **kwargs)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment