@nunenuh
Created July 27, 2020 18:47
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning import LightningModule, Trainer

# build your model
class StandardMNIST(nn.Module):
    def __init__(self):
        super().__init__()
        # mnist images are (1, 28, 28) (channels, height, width)
        self.layer1 = torch.nn.Linear(28 * 28, 128)
        self.layer2 = torch.nn.Linear(128, 256)
        self.layer3 = torch.nn.Linear(256, 10)

    def forward(self, x):
        batch_size, channels, height, width = x.size()
        # flatten: (b, 1, 28, 28) -> (b, 1*28*28)
        x = x.view(batch_size, -1)
        x = self.layer1(x)
        x = torch.relu(x)
        x = self.layer2(x)
        x = torch.relu(x)
        x = self.layer3(x)
        # log-probabilities, paired with F.nll_loss in the training step
        x = torch.log_softmax(x, dim=1)
        return x
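
# quick sanity check (not in the original gist): a fake batch of 4 images
# should come out as log-probabilities of shape (4, 10)
sample = torch.randn(4, 1, 28, 28)
assert StandardMNIST()(sample).shape == (4, 10)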

# extend StandardMNIST and LightningModule at the same time
# this is what I like about Python: extending two classes at the same time
class ExtendMNIST(StandardMNIST, LightningModule):
    def __init__(self):
        super().__init__()

    def training_step(self, batch, batch_idx):
        data, target = batch
        logits = self.forward(data)
        loss = F.nll_loss(logits, target)
        return {'loss': loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
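
# mnist_train_loader is used below but never defined in the gist;
# a minimal sketch of one way to build it with torchvision (the 'data/'
# root path and batch size of 64 are assumptions, not from the original)
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

mnist_train = datasets.MNIST('data/', train=True, download=True,
                             transform=transforms.ToTensor())
mnist_train_loader = DataLoader(mnist_train, batch_size=64, shuffle=True)
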
# run the training
model = ExtendMNIST()
trainer = Trainer(max_epochs=5, gpus=1)
trainer.fit(model, mnist_train_loader)