Skip to content

Instantly share code, notes, and snippets.

@aramakus
Created February 22, 2021 10:17
Show Gist options
  • Save aramakus/ab6e8b8da532c36d2cdf3dda9e2f3cac to your computer and use it in GitHub Desktop.
Save aramakus/ab6e8b8da532c36d2cdf3dda9e2f3cac to your computer and use it in GitHub Desktop.
class Residual(nn.Module):
    """
    Residual module from "Deep Residual Learning for Image Recognition"
    (https://arxiv.org/pdf/1512.03385.pdf)

    Main path: 3x3 conv -> (BatchNorm) -> ReLU -> 3x3 conv -> (BatchNorm).
    The main-path output is added to the shortcut output and passed through
    a final ReLU.

    Args:
        inp_C: number of channels of the input tensor.
        out_C: number of channels produced by the block.
        stride: stride of the first 3x3 convolution (and of the 1x1
            projection on the shortcut, when one is needed).
        batch_norm: when True, BatchNorm2d follows every convolution,
            including the shortcut projection; when False, no
            normalization layer is created anywhere in the block.
    """

    def __init__(self, inp_C: int, out_C: int, stride: int = 1,
                 batch_norm: bool = True) -> None:
        super().__init__()

        # Main path.
        self.conv1 = nn.Conv2d(inp_C, out_C, kernel_size=3, stride=stride,
                               padding=1)
        self.bn1 = nn.BatchNorm2d(out_C) if batch_norm else nn.Identity()
        self.a1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_C, out_C, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_C) if batch_norm else nn.Identity()

        # Shortcut path: a 1x1 projection is only needed when the main path
        # changes the spatial size (stride) or the channel count; otherwise
        # the Sequential stays empty and passes the input through unchanged.
        skip_layers = []
        if stride != 1 or inp_C != out_C:
            skip_layers.append(
                nn.Conv2d(inp_C, out_C, kernel_size=1, stride=stride))
            # Respect the batch_norm flag on the shortcut as well, so that
            # batch_norm=False really yields a normalization-free block.
            if batch_norm:
                skip_layers.append(nn.BatchNorm2d(out_C))
        self.skip = nn.Sequential(*skip_layers)

        self.a2 = nn.ReLU(inplace=True)

    def forward(self, inp):
        """Return ReLU(main_path(inp) + shortcut(inp))."""
        x = self.conv1(inp)
        x = self.bn1(x)
        x = self.a1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        return self.a2(x + self.skip(inp))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment