Skip to content

Instantly share code, notes, and snippets.

@thomasweng15
Created August 24, 2020 15:55
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save thomasweng15/ec74b8ea054a42b1efce09113b3816ad to your computer and use it in GitHub Desktop.
Save thomasweng15/ec74b8ea054a42b1efce09113b3816ad to your computer and use it in GitHub Desktop.
resnet with multigrid
class ResNet(nn.Module):
    """ResNet backbone with a configurable output stride and multi-grid dilation.

    The stem (``conv1`` with stride 2 followed by ``maxpool`` with stride 2)
    reduces the input by a factor of 4. Each later stage requests a stride of
    2; once the accumulated stride equals ``output_stride`` the requested
    stride is converted into dilation instead, so the spatial resolution is
    preserved from that point on.

    Args:
        block: Residual block class. It must expose an ``expansion`` class
            attribute and be constructible as
            ``block(inplanes, planes, stride, downsample, dilation=...)``.
        layers: Sequence with the number of blocks in each of the four
            standard stages (``layers[0]`` .. ``layers[3]``).
        num_classes: Number of output features of the final ``fc`` layer.
        fully_conv: When true, the average pooling uses ``padding=3`` and
            ``stride=1`` and ``forward`` does not flatten the feature map
            before ``fc``.
        remove_avg_pool_layer: When true, ``forward`` skips ``avgpool``.
        output_stride: Target ratio between input and feature-map resolution.
        additional_blocks: Number of extra stages (``layer5``, ``layer6``,
            ...) appended after ``layer4``. Each is built like ``layer4``:
            512 planes, ``layers[3]`` blocks, requested stride 2.
        multi_grid: Per-block dilation multipliers applied to ``layer4`` and
            to the extra stages. When a stage has more blocks than there are
            multipliers, the multipliers are reused cyclically.
    """

    def __init__(self,
                 block,
                 layers,
                 num_classes=1000,
                 fully_conv=False,
                 remove_avg_pool_layer=False,
                 output_stride=32,
                 additional_blocks=0,
                 multi_grid=(1, 1, 1)):
        super(ResNet, self).__init__()

        # Bookkeeping used by _make_layer to reach the requested output
        # stride: the stem below already downsamples by 4.
        self.output_stride = output_stride
        self.current_stride = 4
        self.current_dilation = 1

        self.remove_avg_pool_layer = remove_avg_pool_layer
        self.fully_conv = fully_conv
        self.inplanes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       multi_grid=multi_grid)

        # Extra stages keep the historical attribute names (layer5, layer6,
        # layer7, ...) so state-dict keys stay the same.
        self.additional_blocks = additional_blocks
        for index in range(int(additional_blocks)):
            extra_stage = self._make_layer(block, 512, layers[3], stride=2,
                                           multi_grid=multi_grid)
            setattr(self, "layer{}".format(5 + index), extra_stage)

        if self.fully_conv:
            self.avgpool = nn.AvgPool2d(7, padding=3, stride=1)
        else:
            self.avgpool = nn.AvgPool2d(7)

        # NOTE(review): ``fc`` stays an nn.Linear even when fully_conv is
        # set; presumably the caller swaps it for a 1x1 nn.Conv2d before
        # running in fully convolutional mode -- confirm against callers.
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Zero-mean normal init with std = sqrt(2 / n), where n is
                # kernel height * kernel width * output channels.
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self,
                    block,
                    planes,
                    blocks,
                    stride=1,
                    multi_grid=None):
        """Build one stage made of ``blocks`` residual blocks.

        Updates ``self.current_stride``, ``self.current_dilation`` and
        ``self.inplanes`` as a side effect.

        Args:
            block: Residual block class (see the class docstring).
            planes: Base channel count; the stage outputs
                ``planes * block.expansion`` channels.
            blocks: Number of residual blocks in the stage.
            stride: Requested stride of the first block. It is replaced by
                dilation when the target output stride is already reached.
            multi_grid: Optional sequence of dilation multipliers, one per
                block, reused cyclically when ``blocks`` exceeds its length.
                ``None`` or an empty sequence means a multiplier of 1.

        Returns:
            An ``nn.Sequential`` holding the blocks of the stage.
        """
        downsample = None

        if stride != 1 or self.inplanes != planes * block.expansion:
            if self.current_stride == self.output_stride:
                # Target output stride already reached: trade the
                # subsampling for dilation to keep the resolution.
                self.current_dilation = self.current_dilation * stride
                stride = 1
            else:
                # Otherwise subsample and record the new accumulated stride.
                self.current_stride = self.current_stride * stride

            # The 1x1 projection on the shortcut is never dilated.
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        def block_dilation(index):
            # Cycle through the multipliers so a stage with more blocks than
            # multipliers does not index past the end of ``multi_grid``.
            if not multi_grid:
                return self.current_dilation
            return multi_grid[index % len(multi_grid)] * self.current_dilation

        stage = [block(self.inplanes, planes, stride, downsample,
                       dilation=block_dilation(0))]
        self.inplanes = planes * block.expansion

        for index in range(1, blocks):
            stage.append(block(self.inplanes, planes,
                               dilation=block_dilation(index)))

        return nn.Sequential(*stage)

    def forward(self, x):
        """Run the network on ``x`` and return the output of ``fc``.

        ``avgpool`` is skipped when ``remove_avg_pool_layer`` is set, and the
        features are flattened to ``(x.size(0), -1)`` before ``fc`` unless
        ``fully_conv`` is set.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # Extra stages were registered as layer5, layer6, ... in __init__.
        for index in range(int(self.additional_blocks)):
            x = getattr(self, "layer{}".format(5 + index))(x)

        if not self.remove_avg_pool_layer:
            x = self.avgpool(x)

        if not self.fully_conv:
            x = x.view(x.size(0), -1)

        x = self.fc(x)

        return x
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment