Skip to content

Instantly share code, notes, and snippets.

@SannaPersson
Created March 21, 2021 10:47
Show Gist options
  • Save SannaPersson/ae9850499e199ffb804a785684032db7 to your computer and use it in GitHub Desktop.
Save SannaPersson/ae9850499e199ffb804a785684032db7 to your computer and use it in GitHub Desktop.
class ResidualBlock(nn.Module):
    """Stack of ``num_repeats`` bottleneck units, each optionally wrapped in a skip connection.

    Each unit is ``CNNBlock(channels -> channels // 2, kernel_size=1)`` followed by
    ``CNNBlock(channels // 2 -> channels, kernel_size=3, padding=1)``, so the unit's
    output has the same channel count as its input. The spatial size is presumably
    unchanged as well, assuming ``CNNBlock`` forwards ``kernel_size``/``padding`` to a
    stride-1 convolution. NOTE(review): confirm against the ``CNNBlock`` definition.

    Args:
        channels: Number of input (and output) channels of every unit.
        use_residual: If true, each unit computes ``unit(x) + x``; otherwise ``unit(x)``.
        num_repeats: How many bottleneck units to apply in sequence.
    """

    def __init__(self, channels, use_residual=True, num_repeats=1):
        super().__init__()
        self.layers = nn.ModuleList(
            nn.Sequential(
                CNNBlock(channels, channels // 2, kernel_size=1),
                CNNBlock(channels // 2, channels, kernel_size=3, padding=1),
            )
            for _ in range(num_repeats)
        )
        # Kept as public attributes; external code may read them (e.g. num_repeats).
        self.use_residual = use_residual
        self.num_repeats = num_repeats

    def forward(self, x):
        """Apply every unit in order, adding the unit's input back when ``use_residual`` is set.

        Assumes ``x`` is a tensor with ``channels`` in the channel dimension
        (presumably ``(N, C, H, W)``; TODO confirm with callers).
        """
        for layer in self.layers:
            # Branch explicitly instead of computing `use_residual * x`. With the flag
            # off, that product allocates a zero tensor for nothing and yields NaN
            # wherever x is +/-inf (0 * inf = NaN), corrupting the output.
            if self.use_residual:
                x = layer(x) + x
            else:
                x = layer(x)
        return x
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment