Skip to content

Instantly share code, notes, and snippets.

@SannaPersson
Created March 21, 2021 10:47
Show Gist options
  • Save SannaPersson/78e88914f22e0fda3963f13536cff74e to your computer and use it in GitHub Desktop.
Save SannaPersson/78e88914f22e0fda3963f13536cff74e to your computer and use it in GitHub Desktop.
class CNNBlock(nn.Module):
    """2-D convolution optionally followed by batch norm and LeakyReLU.

    Args:
        in_channels: Number of input channels, forwarded to ``nn.Conv2d``.
        out_channels: Number of channels produced by the convolution; also
            the feature count given to ``nn.BatchNorm2d``.
        bn_act: When true (the default), ``forward`` applies batch
            normalization and ``LeakyReLU(0.1)`` after the convolution, and
            the convolution is created without a bias term. When false,
            ``forward`` returns the raw convolution output and the
            convolution keeps its bias.
        **kwargs: Remaining keyword arguments forwarded unchanged to
            ``nn.Conv2d`` (e.g. ``kernel_size``, ``stride``, ``padding``).
    """

    def __init__(self, in_channels, out_channels, bn_act=True, **kwargs):
        super().__init__()
        # The conv bias is enabled only when no batch norm follows it.
        self.conv = nn.Conv2d(in_channels, out_channels, bias=not bn_act, **kwargs)
        # bn/leaky are created even when bn_act is False so that the set of
        # submodules (and therefore the state_dict keys) does not depend on
        # the flag.
        self.bn = nn.BatchNorm2d(out_channels)
        self.leaky = nn.LeakyReLU(0.1)
        self.use_bn_act = bn_act

    def forward(self, x):
        """Apply the convolution, then batch norm + LeakyReLU if enabled."""
        out = self.conv(x)
        if self.use_bn_act:
            out = self.leaky(self.bn(out))
        return out
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment