Skip to content

Instantly share code, notes, and snippets.

@Niranjankumar-c
Created October 20, 2019 16:31
Show Gist options
  • Save Niranjankumar-c/115d9ddf6e2b4d7afa5f2c35797597cc to your computer and use it in GitHub Desktop.
Save Niranjankumar-c/115d9ddf6e2b4d7afa5f2c35797597cc to your computer and use it in GitHub Desktop.
BatchNorm2d example: a small CNN used to visualize the effect of 2D batch normalization
class CNN_BN(nn.Module):
    """Small LeNet-style CNN with a single BatchNorm2d layer.

    The network maps a batch of single-channel 28x28 images to 10 output
    logits. Batch normalization is applied right after the second
    convolution (before its ReLU), so the first two stages are kept in
    separate ``nn.Sequential`` containers (``features`` and ``features1``).
    That split makes it easy to inspect the normalized activations
    directly, e.g. by calling ``model.features(x)``.

    Expected input shape: ``(N, 1, 28, 28)``. Output shape: ``(N, 10)``.
    """

    def __init__(self):
        # Zero-argument super() avoids hard-coding the class name; the
        # original referred to a non-existent ``MyNetBN`` and raised
        # NameError on construction.
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 3, 5),          # (N, 1, 28, 28) -> (N, 3, 24, 24)
            nn.ReLU(),
            nn.AvgPool2d(2, stride=2),   # (N, 3, 24, 24) -> (N, 3, 12, 12)
            nn.Conv2d(3, 6, 3),          # (N, 3, 12, 12) -> (N, 6, 10, 10)
            nn.BatchNorm2d(6),           # per-channel normalization; shape unchanged
        )
        self.features1 = nn.Sequential(
            nn.ReLU(),
            nn.AvgPool2d(2, stride=2),   # (N, 6, 10, 10) -> (N, 6, 5, 5)
        )
        self.classifier = nn.Sequential(
            nn.Linear(150, 25),          # (N, 150) -> (N, 25); 150 = 6 * 5 * 5
            nn.ReLU(),
            nn.Linear(25, 10),           # (N, 25) -> (N, 10)
        )

    def forward(self, x):
        """Return class logits of shape ``(N, 10)`` for input ``x``.

        ``x`` is expected to be ``(N, 1, 28, 28)``; other spatial sizes will
        not match the 150 input features of the first ``Linear`` layer.
        """
        x = self.features(x)
        x = self.features1(x)
        # Flatten everything except the batch dimension: (N, 6, 5, 5) -> (N, 150).
        x = x.flatten(1)
        x = self.classifier(x)
        return x
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment