import torch.nn as nn


class Residual(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(dim, dim, 7, 1, 3, groups=dim),  # depth-wise 7x7
            nn.BatchNorm2d(dim),
            nn.Conv2d(dim, dim * 4, 1),  # inverted bottleneck: expand ...
            nn.ReLU(),
            nn.Conv2d(dim * 4, dim, 1),  # ... and project back
        )

    def forward(self, x):
        return self.layer(x) + x


def block(ins, outs, repeats):
    layers = [Residual(ins) for _ in range(repeats)]
    if ins != outs:
        layers.extend([
            nn.BatchNorm2d(ins),
            nn.Conv2d(ins, outs, 2, 2),  # downsample by factor 2
        ])
    return nn.Sequential(*layers)


def convnet18(outs):
    features = nn.Sequential(
        nn.Conv2d(3, 64, 4, 4),  # e.g. (3, 224, 224) -> (64, 56, 56)
        block(64, 128, 2),
        block(128, 256, 2),
        block(256, 512, 2),
        block(512, 512, 2),  # remaining (512, 7, 7)
        nn.AdaptiveAvgPool2d(1),  # remaining (512, 1, 1)
        nn.Flatten(),  # remaining (512)
    )
    return nn.Sequential(features, nn.Linear(512, outs))

# This network is a modification of resnet18, inspired by the observations
# reported in https://arxiv.org/abs/2201.03545 (A ConvNet for the 2020s).
# We try to stay close to resnet18's roughly 11 million parameters
# (this variant reaches about 12 million).
# Compared to resnet18 (following A ConvNet for the 2020s):
#   - stronger stem downsampling (a single factor-4 conv instead of a factor-2 conv)
#   - wider kernels (7x7 vs 3x3)
#   - fewer normalization layers
#   - fewer activation layers
#   - depth-wise convolutions (much faster to compute!)
#   - inverted residual in the point-wise convs (more expressive)
# Differences from A ConvNet for the 2020s:
#   - no GELU activation (slow to compute)
#   - no transpose + LayerNorm (slow to compute)
#   - fewer repeats per stage (avoids an exploding parameter count and slow compute)
#   - no stochastic depth (diminishing returns with fewer layers)
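
# A minimal smoke test, assuming the definitions above: it instantiates the
# model, reports the parameter count, and checks the feature-map and output
# shapes for a 224x224 input. The number of classes (10) is arbitrary.
if __name__ == "__main__":
    import torch

    model = convnet18(outs=10)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"parameters: {n_params / 1e6:.1f} mio")

    x = torch.randn(2, 3, 224, 224)
    feats = model[0][:-2](x)  # stem + blocks, before pooling and flatten
    print(feats.shape)        # expected: torch.Size([2, 512, 7, 7])
    print(model(x).shape)     # expected: torch.Size([2, 10])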