import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pytorch_lightning as pl
from torchvision import models


class LightningResNet(pl.LightningModule):
    def __init__(self, net_pretrained, criterion=F.cross_entropy,
                 num_classes=4, optimizer=None, scheduler=None):
        super().__init__()
        self.net = net_pretrained
        # replace the classification head to match the number of classes
        num_ftrs = self.net.fc.in_features
        self.net.fc = nn.Linear(num_ftrs, num_classes)
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
    def forward(self, x):
        return self.net(x)

    @torch.no_grad()
    def accuracy(self, outputs, labels):
        # fraction of predictions that match the labels
        preds = torch.argmax(outputs, dim=1)
        return (preds == labels).float().mean()
    def training_step(self, batch, batch_idx):
        loss, acc = self._shared_eval_step(batch, batch_idx)
        self.log_dict({'train_loss': loss, 'train_accuracy': acc})
        return loss

    def test_step(self, batch, batch_idx):
        # Lightning already runs test/validation under torch.no_grad(),
        # so no explicit context manager is needed here
        loss, acc = self._shared_eval_step(batch, batch_idx)
        metrics = {'test_loss': loss, 'test_acc': acc}
        self.log_dict(metrics)
        return metrics

    def validation_step(self, batch, batch_idx):
        loss, acc = self._shared_eval_step(batch, batch_idx)
        metrics = {'val_loss': loss, 'val_accuracy': acc}
        self.log_dict(metrics, prog_bar=True)
        return metrics
    def _shared_eval_step(self, batch, batch_idx):
        # common forward pass shared by the train/val/test steps
        images, labels = batch
        out = self(images)
        loss = self.criterion(out, labels)
        acc = self.accuracy(out, labels)
        return loss, acc
    def configure_optimizers(self):
        if self.optimizer is None:
            # default: SGD plus a plateau scheduler driven by validation accuracy
            optimizer = optim.SGD(self.net.parameters(), lr=0.001, momentum=0.9)
            plateau_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, mode='max', patience=3, verbose=True)
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": plateau_scheduler,
                    "monitor": "val_accuracy",
                },
            }
        if self.scheduler is None:
            # an external optimizer was supplied without a scheduler
            return {"optimizer": self.optimizer}
        return {
            "optimizer": self.optimizer,
            "lr_scheduler": {
                "scheduler": self.scheduler,
                "monitor": "val_loss",
            },
        }
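

# Quick sanity check (illustrative, not part of the original gist): the wrapped
# network should map a batch of 224x224 RGB images to `num_classes` logits.
# An untrained backbone is used here to avoid downloading weights.
_probe = LightningResNet(models.resnet18(), num_classes=4)
assert _probe(torch.randn(2, 3, 224, 224)).shape == (2, 4)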
# load data loaders and parameters
# ...
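
# A minimal sketch of what the elided setup might look like. The paths,
# batch size, and hyperparameter values below are illustrative assumptions,
# not the originals.
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# standard ImageNet preprocessing expected by the pretrained ResNet weights
tfm = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
image_datasets = {split: datasets.ImageFolder(f"data/{split}", tfm)  # assumed layout
                  for split in ('train', 'val', 'test')}
dataloaders = {split: DataLoader(ds, batch_size=32, shuffle=(split == 'train'))
               for split, ds in image_datasets.items()}
classes = image_datasets['train'].classes
lr, momentum = 0.001, 0.9       # assumed optimizer settings
plateau_lr_decrease = 3         # scheduler patience (assumed)
patience, max_epochs = 5, 20    # early-stopping patience / epoch cap (assumed)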
# get the pretrained backbone (newer torchvision versions prefer
# models.resnet18(weights=models.ResNet18_Weights.DEFAULT))
net = models.resnet18(pretrained=True)
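
# Optional variant (not in the original gist): for feature extraction rather
# than full fine-tuning, freeze everything except the classification head
# before building the optimizer:
# for name, p in net.named_parameters():
#     if not name.startswith('fc.'):
#         p.requires_grad = False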
# build the optimizer only after LightningResNet has swapped in the new
# classification head, so the new fc parameters are registered with it
clf_resnet18 = LightningResNet(net, num_classes=len(classes))
optimizer = optim.SGD(clf_resnet18.parameters(), lr=lr, momentum=momentum)
plateau_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', patience=plateau_lr_decrease, verbose=True)
clf_resnet18.optimizer = optimizer
clf_resnet18.scheduler = plateau_scheduler
# prepare callbacks and logging
callbacks = [pl.callbacks.EarlyStopping("val_accuracy", mode='max', patience=patience)]
logger = pl.loggers.CSVLogger("logs", name="resnet18-a1-transfer-logs", version=0)
# prepare trainer and launch!
# note: `gpus=1` is deprecated in recent PyTorch Lightning; `devices=1` replaces it
trainer_resnet18 = pl.Trainer(accelerator="auto", devices=1, callbacks=callbacks,
                              max_epochs=max_epochs, log_every_n_steps=1, logger=logger)
trainer_resnet18.fit(model=clf_resnet18,
                     train_dataloaders=dataloaders['train'],
                     val_dataloaders=dataloaders['val'])
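
# Exercise the test_step defined above (illustrative follow-up; the original
# gist stops after fit()). Assumes the held-out 'test' split sketched earlier.
trainer_resnet18.test(model=clf_resnet18, dataloaders=dataloaders['test'])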