Last active: August 9, 2018 15:23
Save emilemathieu/acb3dc9a0125bd9e30b69ff857cc3eec to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class MyNet(nn.Module):
    """Small LeNet-style convolutional network built from ``nn`` layers.

    Layer stack, in order::

        Conv2d(1 -> 10, kernel 5) -> MaxPool2d(2, 2) -> ReLU
        Conv2d(10 -> 20, kernel 5) -> MaxPool2d(2, 2) -> ReLU
        Flatten
        Linear(320 -> 120) -> ReLU -> Linear(120 -> 10) -> ReLU

    Besides ``forward``, the class exposes manual ``backward`` and ``step``
    methods that delegate to the same-named methods of its sub-layers, so
    the ``nn`` package in use must provide those methods (presumably a
    from-scratch framework rather than PyTorch autograd; TODO confirm).
    """

    def __init__(self):
        # Initialise base-class state before registering sub-layers.
        # PyTorch-style Module bases reject submodule assignment without
        # this call; for a base without such state it is a harmless no-op.
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 10, kernel_size=5),
            nn.MaxPool2d(2, 2),
            nn.ReLU(),
            nn.Conv2d(10, 20, kernel_size=5),
            nn.MaxPool2d(2, 2),
            nn.ReLU(),
        )
        self.flatten = nn.Flatten()
        self.classifier = nn.Sequential(
            # 320 = 20 channels * 4 * 4. That matches a single-channel
            # 28x28 input (28 -> 24 -> 12 -> 8 -> 4); assumes MNIST-sized
            # images, TODO confirm against the data loader.
            nn.Linear(320, 120),
            nn.ReLU(),
            nn.Linear(120, 10),
            # NOTE(review): a ReLU on the output layer clamps negative
            # scores to 0. If these outputs feed a softmax/cross-entropy
            # loss, this is probably unintended. Kept to preserve behaviour.
            nn.ReLU(),
        )

    def forward(self, x):
        """Run ``x`` through features, flatten and classifier.

        Returns the classifier output reshaped to 2-D, with one row per
        sample (``x.shape[0]`` rows).
        """
        x = self.features(x)
        x = self.flatten(x)
        x = self.classifier(x)
        # Presumably already (batch, 10) after the final Linear/ReLU; the
        # reshape only guarantees a 2-D result.
        return x.reshape(x.shape[0], -1)

    def backward(self, output_grad):
        """Back-propagate ``output_grad`` through the layers in reverse order.

        Returns whatever ``self.features.backward`` returns (presumably the
        gradient with respect to the network input). The final reshape in
        ``forward`` is not undone here. That is fine as long as the
        classifier output is already 2-D (assumed, TODO confirm).
        """
        output_grad = self.classifier.backward(output_grad)
        output_grad = self.flatten.backward(output_grad)
        return self.features.backward(output_grad)

    def step(self, optimizer):
        """Let ``optimizer`` update the classifier and feature layers.

        ``self.flatten`` is not stepped (presumably it has no parameters).
        """
        self.classifier.step(optimizer)
        self.features.step(optimizer)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.