Skip to content

Instantly share code, notes, and snippets.

@Niranjankumar-c
Created October 20, 2019 13:53
Show Gist options
  • Save Niranjankumar-c/e3801007050850bda954213a06d69df4 to your computer and use it in GitHub Desktop.
Save Niranjankumar-c/e3801007050850bda954213a06d69df4 to your computer and use it in GitHub Desktop.
Two small fully connected networks, one without and one with batch normalization, used to visualize the effect of batch norm.
class MyNet(nn.Module):
    """Plain fully connected classifier for flattened images (no batch norm).

    Each input sample is flattened to a vector of ``in_features`` values
    (784 = 28 x 28 by default). The vector then passes through
    ``Linear -> ReLU`` once per entry of ``hidden_sizes``, and a final
    ``Linear`` produces ``num_classes`` outputs.

    With the default arguments, the layer layout is identical to the
    original hard-coded 784 -> 48 -> 24 -> 10 network. That includes the
    ``state_dict`` keys ``classifier.0``, ``classifier.2`` and
    ``classifier.4``.

    Args:
        in_features: Number of values per flattened input sample.
        hidden_sizes: Output widths of the hidden ``Linear`` layers, in order.
        num_classes: Number of outputs of the final ``Linear`` layer.
    """

    def __init__(self, in_features=784, hidden_sizes=(48, 24), num_classes=10):
        super().__init__()
        layers = []
        prev = in_features
        for width in hidden_sizes:
            layers += [nn.Linear(prev, width), nn.ReLU()]
            prev = width
        layers.append(nn.Linear(prev, num_classes))
        self.classifier = nn.Sequential(*layers)

    def forward(self, x):
        """Return a ``(batch, num_classes)`` output tensor.

        ``x`` may have any shape whose first dimension is the batch and whose
        remaining dimensions multiply to ``in_features``, e.g. ``(B, 28, 28)``
        or ``(B, 1, 28, 28)``.
        """
        # reshape (not view): view raises on non-contiguous tensors such as
        # transposed/permuted batches; reshape copies only when it must.
        x = x.reshape(x.size(0), -1)
        return self.classifier(x)
class MyNetBN(nn.Module):
    """Fully connected classifier for flattened images with batch normalization.

    The architecture matches the plain network, except that every hidden
    ``Linear`` is followed by ``BatchNorm1d`` before its ``ReLU``:
    ``Linear -> BatchNorm1d -> ReLU`` per hidden layer, then a final
    ``Linear`` producing ``num_classes`` outputs.

    With the default arguments, the layer layout is identical to the
    original hard-coded 784 -> 48 -> 24 -> 10 network. That includes the
    ``state_dict`` keys ``classifier.0`` through ``classifier.6``.

    Note:
        In training mode, ``BatchNorm1d`` needs more than one sample per
        batch to compute batch statistics. Call ``.eval()`` before running
        single-sample inputs.

    Args:
        in_features: Number of values per flattened input sample.
        hidden_sizes: Output widths of the hidden ``Linear`` layers, in order.
        num_classes: Number of outputs of the final ``Linear`` layer.
    """

    def __init__(self, in_features=784, hidden_sizes=(48, 24), num_classes=10):
        super().__init__()
        layers = []
        prev = in_features
        for width in hidden_sizes:
            # Batch norm is applied to the pre-activation of each hidden layer.
            layers += [nn.Linear(prev, width), nn.BatchNorm1d(width), nn.ReLU()]
            prev = width
        layers.append(nn.Linear(prev, num_classes))
        self.classifier = nn.Sequential(*layers)

    def forward(self, x):
        """Return a ``(batch, num_classes)`` output tensor.

        ``x`` may have any shape whose first dimension is the batch and whose
        remaining dimensions multiply to ``in_features``, e.g. ``(B, 28, 28)``
        or ``(B, 1, 28, 28)``.
        """
        # reshape (not view): view raises on non-contiguous tensors such as
        # transposed/permuted batches; reshape copies only when it must.
        x = x.reshape(x.size(0), -1)
        return self.classifier(x)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment