class MyNet(nn.Module):
    """Small LeNet-style CNN. Here `nn` refers to a PyTorch-like module whose
    layers also expose backward() and step() for manual backpropagation."""

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor: (N, 1, 28, 28) -> (N, 20, 4, 4)
        self.features = nn.Sequential(
            nn.Conv2d(1, 10, kernel_size=5),
            nn.MaxPool2d(2, 2),
            nn.ReLU(),
            nn.Conv2d(10, 20, kernel_size=5),
            nn.MaxPool2d(2, 2),
            nn.ReLU()
        )
        self.flatten = nn.Flatten()
        # Fully connected classifier: 20 * 4 * 4 = 320 inputs, 10 class scores
        self.classifier = nn.Sequential(
            nn.Linear(320, 120),
            nn.ReLU(),
            nn.Linear(120, 10),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.features(x)
        x = self.flatten(x)
        x = self.classifier(x)
        return x.reshape(x.shape[0], -1)  # already (batch, 10); the reshape is a no-op

    def backward(self, output_grad):
        # Propagate the loss gradient through the blocks in reverse order
        output_grad = self.classifier.backward(output_grad)
        output_grad = self.flatten.backward(output_grad)
        return self.features.backward(output_grad)

    def step(self, optimizer):
        # Let each block update its own parameters with the given optimizer
        self.classifier.step(optimizer)
        self.features.step(optimizer)
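
For context, a minimal sketch of how this manual forward/backward/step API might be driven in one training iteration is given below. The batch shapes, the MSE-style loss gradient, and the `SGD` optimizer name are assumptions for illustration only, not part of the gist; they stand in for whatever loss and optimizer the surrounding library actually provides.

import numpy as np

# Hypothetical training step, assuming the custom layers operate on NumPy
# arrays and the same library ships an SGD-like optimizer (name is assumed).
model = MyNet()
optimizer = nn.SGD(lr=0.01)                    # assumption: optimizer class and signature

x = np.random.randn(64, 1, 28, 28)             # a batch of MNIST-sized images
y = np.eye(10)[np.random.randint(0, 10, 64)]   # one-hot targets

scores = model.forward(x)                      # (64, 10) class scores
loss_grad = 2 * (scores - y) / len(x)          # gradient of a simple MSE loss (illustrative)
model.backward(loss_grad)                      # backpropagate through classifier, flatten, features
model.step(optimizer)                          # update all parameters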