-
-
Save aramakus/ab6e8b8da532c36d2cdf3dda9e2f3cac to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Residual(nn.Module):
    """Basic two-convolution residual block.

    From "Deep Residual Learning for Image Recognition"
    (https://arxiv.org/pdf/1512.03385.pdf).

    Main path: conv3x3(stride) -> BN -> ReLU -> conv3x3 -> BN. The result is
    added to the shortcut path and passed through a final ReLU. The shortcut
    is the identity unless the stride or the channel count changes. In that
    case a strided 1x1 convolution (followed by BN when ``batch_norm`` is set)
    projects the input to the output shape so the two paths can be summed.

    Args:
        inp_C: Number of input channels.
        out_C: Number of output channels.
        stride: Stride of the first 3x3 convolution and of the projection
            shortcut. A stride of 2 halves the spatial size (rounding up).
        batch_norm: If True, apply ``BatchNorm2d`` after every convolution,
            including the projection shortcut. If False, no batch
            normalization is used anywhere in the block.
    """

    def __init__(self, inp_C: int, out_C: int, stride: int = 1,
                 batch_norm: bool = True):
        super().__init__()
        # Convs keep their default bias (as in the original layout) so existing
        # checkpoints still load, even though a bias before BN is redundant.
        self.conv1 = nn.Conv2d(inp_C, out_C, kernel_size=3, stride=stride,
                               padding=1)
        self.bn1 = nn.BatchNorm2d(out_C) if batch_norm else nn.Identity()
        self.a1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_C, out_C, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_C) if batch_norm else nn.Identity()

        changed_shape = stride != 1 or inp_C != out_C
        if changed_shape:
            # Projection shortcut. The conv stays at index 0 so its state_dict
            # keys (``skip.0.*``) are the same whether or not BN follows it.
            projection = [nn.Conv2d(inp_C, out_C, kernel_size=1, stride=stride)]
            if batch_norm:
                projection.append(nn.BatchNorm2d(out_C))
            self.skip = nn.Sequential(*projection)
        else:
            self.skip = nn.Identity()
        # In-place is safe here: it acts on the freshly allocated sum below.
        self.a2 = nn.ReLU(inplace=True)

    def forward(self, inp):
        """Apply the block to a batch of feature maps.

        Args:
            inp: Input tensor with ``inp_C`` channels, presumably laid out as
                ``(N, C, H, W)`` (required by ``BatchNorm2d``).

        Returns:
            A tensor with ``out_C`` channels and spatial size reduced by
            ``stride``, with all values non-negative (final ReLU).
        """
        x = self.a1(self.bn1(self.conv1(inp)))
        x = self.bn2(self.conv2(x))
        return self.a2(x + self.skip(inp))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment