Skip to content

Instantly share code, notes, and snippets.

@cosmic-cortex
Created June 3, 2020 11:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save cosmic-cortex/fd8f9bf4707918a4489775515e54de0b to your computer and use it in GitHub Desktop.
Save cosmic-cortex/fd8f9bf4707918a4489775515e54de0b to your computer and use it in GitHub Desktop.
import os
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.core.lightning import LightningModule
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.models import resnet18
# Input preprocessing for the stock torchvision ResNet-18:
#   - Resize(224) scales the smaller image edge to 224 pixels.
#   - Grayscale(3) emits a 3-channel image (same gray value in every channel)
#     so it fits resnet18's default 3-channel input layer.
#   - ToTensor() converts the PIL image to a float tensor scaled to [0, 1].
# No Normalize() step is applied.
transform = transforms.Compose(
    [transforms.Resize(224), transforms.Grayscale(3), transforms.ToTensor()]
)

# MNIST training split, stored under the current working directory and
# downloaded there if it is not already present.
dataset = MNIST(os.getcwd(), train=True, download=True, transform=transform)

# shuffle=True reshuffles the training set at every epoch. Without it, each
# epoch presents the samples in the same fixed order, which is the wrong
# default for stochastic-gradient training.
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
class ResNetModel(LightningModule):
    """ResNet-18 image classifier wrapped as a PyTorch Lightning module.

    The network is built from scratch (``pretrained=False``) and trained with
    Adam on a cross-entropy loss.

    Args:
        num_classes: Number of output classes of the final layer. Defaults to
            10, the number of MNIST digit classes this script trains on.
        learning_rate: Learning rate passed to Adam. Defaults to ``1e-3``
            (the value previously hard-coded).
    """

    def __init__(self, num_classes: int = 10, learning_rate: float = 1e-3):
        super().__init__()
        # Stored on the module so configure_optimizers() can read it.
        self.learning_rate = learning_rate
        self.model = resnet18(pretrained=False, num_classes=num_classes)

    def forward(self, x):
        """Return the class scores produced by the wrapped ResNet for ``x``.

        No softmax is applied: the scores are the raw logits that
        ``F.cross_entropy`` expects in ``training_step``.
        """
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute the cross-entropy loss for one training batch.

        Args:
            batch: ``(inputs, targets)`` pair as yielded by the training
                DataLoader.
            batch_idx: Index of the batch within the epoch (unused here).

        Returns:
            A dict holding the loss under ``"loss"`` (the value Lightning
            back-propagates) and a ``"log"`` dict that reports the same value
            as ``"train_loss"``.
        """
        x, y = batch
        # Calling self(...) rather than self.forward(...) so nn.Module hooks run.
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        # NOTE(review): the "log" return key is the pre-1.0 Lightning logging
        # API; newer releases expect self.log("train_loss", loss) instead and
        # may ignore this key. Confirm against the installed Lightning version.
        tensorboard_logs = {"train_loss": loss}
        return {"loss": loss, "log": tensorboard_logs}

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters of this module."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
# Train only when this file is executed as a script, not when it is imported.
# The guard also keeps processes started with the "spawn" start method (e.g.
# DataLoader workers or DDP spawn), which re-import the main module, from
# recursively launching training.
if __name__ == "__main__":
    model = ResNetModel()
    # Single-machine run for 50 epochs; num_nodes=1 is the Trainer default,
    # kept explicit for readability.
    trainer = Trainer(num_nodes=1, max_epochs=50)
    trainer.fit(model, train_loader)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment