Skip to content

Instantly share code, notes, and snippets.

@krsnewwave
Last active April 10, 2022 16:22
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save krsnewwave/43787b2eef3c2a358f624f60682fe063 to your computer and use it in GitHub Desktop.
Save krsnewwave/43787b2eef3c2a358f624f60682fe063 to your computer and use it in GitHub Desktop.
class LightningResNet(pl.LightningModule):
    """LightningModule wrapping a pretrained ResNet-style network for fine-tuning.

    The wrapped network's ``fc`` layer is replaced in ``__init__`` by a fresh
    ``nn.Linear(fc.in_features, num_classes)``.

    Args:
        net_pretrained: Network exposing an ``fc`` attribute with an
            ``in_features`` field (e.g. a torchvision ResNet). It is modified
            in place.
        device: Accepted for backward compatibility but not used anywhere in
            this class (presumably the Lightning Trainer handles placement).
        criterion: Loss called as ``criterion(logits, labels)``.
        num_classes: Output width of the new ``fc`` head.
        optimizer: Optional pre-built optimizer. If ``None``, an SGD optimizer
            with a ReduceLROnPlateau scheduler on ``val_accuracy`` is created in
            ``configure_optimizers``. If supplied, build it AFTER constructing
            this module (e.g. from ``model.parameters()``); otherwise it will
            not contain the parameters of the replacement ``fc`` head.
        scheduler: Optional LR scheduler for ``optimizer``, monitored on
            ``val_loss``. May be ``None``.
    """

    def __init__(self, net_pretrained, device='cpu', criterion=F.cross_entropy,
                 num_classes=4, optimizer=None, scheduler=None):
        super().__init__()
        self.net = net_pretrained
        # Swap the classification head so its output width matches num_classes.
        num_ftrs = self.net.fc.in_features
        self.net.fc = nn.Linear(num_ftrs, num_classes)
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler

    def forward(self, x):
        """Return the raw logits produced by the wrapped network."""
        return self.net(x)

    @torch.no_grad()
    def accuracy(self, outputs, labels):
        """Return the fraction of rows whose argmax over dim 1 equals ``labels``.

        Assumes ``outputs`` is (batch, num_classes) logits and ``labels`` holds
        integer class indices (TODO confirm against the dataloader). The result
        is a 0-d float tensor on the same device as ``outputs``. Computing it
        on-device avoids the host synchronisation that ``.item()`` forces.
        """
        preds = outputs.argmax(dim=1)
        return (preds == labels).float().mean()

    def training_step(self, batch, batch_idx):
        """Compute and log training loss/accuracy; return the loss for backprop."""
        loss, acc = self._shared_eval_step(batch, batch_idx)
        metrics = {'train_loss': loss, 'train_accuracy': acc}
        self.log_dict(metrics)
        return {'loss': loss, "train_accuracy": acc}

    def test_step(self, batch, batch_idx):
        """Compute and log test loss/accuracy without tracking gradients."""
        with torch.no_grad():
            loss, acc = self._shared_eval_step(batch, batch_idx)
        metrics = {"test_acc": acc, "test_loss": loss}
        self.log_dict(metrics)
        return metrics

    def validation_step(self, batch, batch_idx):
        """Compute and log validation loss/accuracy (shown in the progress bar)."""
        with torch.no_grad():
            loss, acc = self._shared_eval_step(batch, batch_idx)
        metrics = {'val_loss': loss, 'val_accuracy': acc}
        self.log_dict(metrics, prog_bar=True)
        return metrics

    def _shared_eval_step(self, batch, batch_idx):
        """Run a forward pass on ``(images, labels)``; return ``(loss, accuracy)``."""
        images, labels = batch
        out = self(images)
        loss = self.criterion(out, labels)
        accu = self.accuracy(out, labels)
        return loss, accu

    def configure_optimizers(self):
        """Return the optimizer configuration for Lightning.

        With no user optimizer, this builds SGD(lr=0.001, momentum=0.9) plus a
        ReduceLROnPlateau(mode='max') scheduler driven by ``val_accuracy``.
        With a user optimizer, it returns that optimizer and, if one was given,
        the user scheduler monitored on ``val_loss``.
        """
        if self.optimizer is None:
            optimizer = optim.SGD(self.net.parameters(), lr=0.001, momentum=0.9)
            plateau_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, mode='max', patience=3, verbose=True)
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": plateau_scheduler,
                    "monitor": "val_accuracy",
                },
            }
        if self.scheduler is None:
            # Previously this returned {"scheduler": None}, which is not a valid
            # scheduler config; return only the optimizer instead.
            return {"optimizer": self.optimizer}
        return {
            "optimizer": self.optimizer,
            "lr_scheduler": {
                "scheduler": self.scheduler,
                "monitor": "val_loss",
            },
        }
# load data loaders and parameters
# ...
# get pretrained
net = models.resnet18(pretrained=True)
# Build the LightningModule FIRST: its __init__ replaces net.fc with a new
# nn.Linear head. An optimizer built from net.parameters() before that swap
# would hold the old head's parameters and never update the new one.
clf_resnet18 = LightningResNet(net, num_classes=len(classes), device=device)
# create and attach the optimizer/scheduler over the final parameter set
optimizer = optim.SGD(clf_resnet18.parameters(), lr=lr, momentum=momentum)
plateau_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', patience=plateau_lr_decrease, verbose=True)
clf_resnet18.optimizer = optimizer
clf_resnet18.scheduler = plateau_scheduler
# prepare callbacks
callbacks = [pl.callbacks.EarlyStopping("val_accuracy", mode='max', patience=patience)]
logger = pl.loggers.CSVLogger("logs", name="resnet18-a1-transfer-logs", version=0)
# prepare trainer and launch!
# devices=1 with accelerator="auto" uses one GPU when available and falls back
# to CPU; the old gpus=1 forced CUDA and contradicted "auto".
trainer_resnet18 = pl.Trainer(accelerator="auto", devices=1, callbacks=callbacks,
                              max_epochs=max_epochs, log_every_n_steps=1, logger=logger)
trainer_resnet18.fit(model=clf_resnet18, train_dataloaders=dataloaders['train'],
                     val_dataloaders=dataloaders['val'])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment