@nbertagnolli
Created January 1, 2022 21:14
A simple MLP model for MNIST that uses a custom dropout class.
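import torch
import torch.nn as nn

# NOTE: `Dropout` is the gist's custom dropout module. It is defined
# separately and is not shown here.


class MNISTModel(torch.nn.Module):
    """Three-layer MLP for flattened 28x28 MNIST digits."""

    def __init__(self):
        super(MNISTModel, self).__init__()
        self.layer_1 = nn.Linear(28 * 28, 512)
        self.layer_2 = nn.Linear(512, 512)
        self.layer_3 = nn.Linear(512, 10)
        self.dropout = Dropout(.5)  # p = 0.5, assumed to be the drop probability

    def forward(self, x):
        x = x.view(-1, 28 * 28)  # flatten each image to a 784-dim vector
        # ReLU added between layers: without a nonlinearity, stacked Linear
        # layers collapse into a single linear map.
        x = torch.relu(self.layer_1(x))
        x = torch.relu(self.layer_2(x))
        x = self.dropout(x)
        output = self.layer_3(x)  # raw class scores (logits), shape (N, 10)
        return output  # the original returned the 512-dim hidden `x` by mistake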
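The gist does not include its `Dropout` class. Below is a minimal sketch of one common way to write such a module, inverted dropout. It assumes that `p` is the probability of dropping a unit, as in `torch.nn.Dropout`.

The class name `InvertedDropout` is hypothetical and is not the author's implementation. During training, each activation is zeroed with probability `p` and the survivors are scaled by `1 / (1 - p)`, which keeps their expected value unchanged. In eval mode the module passes its input through untouched.

```python
import torch
import torch.nn as nn


class InvertedDropout(nn.Module):
    """Hypothetical stand-in for the gist's custom Dropout (inverted dropout).

    Training: zero each element with probability p, scale survivors by 1/(1-p).
    Eval: identity.
    """

    def __init__(self, p: float = 0.5):
        super().__init__()
        if not 0.0 <= p < 1.0:
            raise ValueError("p must be in [0, 1)")
        self.p = p

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training or self.p == 0.0:
            return x
        # Keep-mask: 1 with probability (1 - p), 0 with probability p.
        keep = torch.bernoulli(torch.full_like(x, 1.0 - self.p))
        return x * keep / (1.0 - self.p)


torch.manual_seed(0)
layer = InvertedDropout(p=0.5)
x = torch.ones(4, 8)

layer.train()              # modules start in training mode; made explicit here
y_train = layer(x)         # each entry is either 0.0 (dropped) or 2.0 (kept, scaled)
print(y_train)

layer.eval()               # dropout is disabled at evaluation time
y_eval = layer(x)
print(torch.equal(y_eval, x))
```

Doing the rescaling at training time means inference needs no correction factor. Calling `model.eval()` is enough to turn dropout off. To try the `MNISTModel` snippet with this sketch, one option is to bind `Dropout = InvertedDropout` before the model class is defined.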