@adimyth · Created August 25, 2021 08:17
Transfer Learning with Poisson Loss (PyTorch Lightning)
from glob import glob
import os
import re

import natsort
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import WandbLogger
from torchmetrics.functional import accuracy, auroc
from torchmetrics.functional import mean_absolute_error, mean_squared_error
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, random_split
import torchvision.models as models
from torchvision import transforms
from torchvision.datasets import ImageFolder
import wandb
# Kaggle specific way to use Wandb
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
wandb_api = user_secrets.get_secret("wandb-key")
wandb.login(key=wandb_api)
wandb.init(project="count-the-green-boxes")
wandb_logger = WandbLogger(project="count-green-boxes-lightning", job_type="train")
# Constants
RANDOM_STATE = 42
NUM_CLASSES = 98
# Seed Everything
pl.seed_everything(RANDOM_STATE)
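# What F.poisson_nll_loss computes when log_input=False (a minimal sketch;
# the rate/target values below are toy numbers, illustrative only):
#   loss = rate - target * log(rate + eps), averaged over the batch
_rate = torch.tensor([3.0, 7.0])    # predicted counts
_target = torch.tensor([4.0, 6.0])  # observed counts
_manual = (_rate - _target * torch.log(_rate + 1e-8)).mean()
assert torch.allclose(_manual, F.poisson_nll_loss(_rate, _target, log_input=False))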
# DataModule
class DataModule(pl.LightningDataModule):
    def __init__(self, batch_size: int = 64, data_dir: str = ""):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        # Standard ImageNet preprocessing: resize, center-crop, then normalize
        # with the ImageNet channel means and standard deviations
        self.transform = transforms.Compose(
            [
                transforms.Resize(size=256),
                transforms.CenterCrop(size=224),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )

    def setup(self, stage=None):
        # 80/20 train/validation split
        dataset = ImageFolder(self.data_dir)
        num_train = int(0.8 * len(dataset))
        num_valid = len(dataset) - num_train
        self.train, self.val = random_split(dataset, [num_train, num_valid])
        self.train.dataset.transform = self.transform
        self.val.dataset.transform = self.transform

    def train_dataloader(self):
        return DataLoader(
            self.train, batch_size=self.batch_size, shuffle=True, num_workers=2
        )

    def val_dataloader(self):
        return DataLoader(self.val, batch_size=self.batch_size, num_workers=2)
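# Standalone usage sketch for the datamodule (the directory path below is
# hypothetical; batches come out as (B, 3, 224, 224) after the transforms):
#   dm = DataModule(batch_size=32, data_dir="/path/to/train")
#   dm.setup()
#   xb, yb = next(iter(dm.train_dataloader()))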
# Model
class CountModel(pl.LightningModule):
    def __init__(
        self, input_shape, num_classes: int = 100, learning_rate: float = 2e-4
    ):
        super().__init__()
        # log hyperparameters
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        self.dim = input_shape
        self.num_classes = num_classes
        # Pretrained ResNet-18 backbone with its ImageNet fc head stripped,
        # frozen so that only the new classifier below is trained
        backbone = models.resnet18(pretrained=True)
        self.feature_extractor = nn.Sequential(*list(backbone.children())[:-1])
        self.feature_extractor.eval()
        for param in self.feature_extractor.parameters():
            param.requires_grad = False
        n_sizes = self._get_conv_output(input_shape)
        self.classifier = nn.Linear(n_sizes, num_classes)

    # returns the size of the output tensor going into the Linear layer
    # from the conv block
    def _get_conv_output(self, shape):
        batch_size = 1
        dummy_input = torch.rand(batch_size, *shape)
        output_feat = self._forward_features(dummy_input)
        n_size = output_feat.view(batch_size, -1).size(1)
        return n_size

    # returns the feature tensor from the conv block
    def _forward_features(self, x):
        x = self.feature_extractor(x)
        return x

    # will be used during inference
    def forward(self, x):
        x = self._forward_features(x)
        x = x.view(x.size(0), -1)
        x = F.log_softmax(self.classifier(x), dim=1)
        return x
    # logic for a single training step
    def training_step(self, batch, batch_idx):
        x, y = batch
        log_probs = self(x)
        preds = torch.argmax(log_probs, dim=1)
        # argmax is not differentiable, so compute the Poisson NLL on the
        # expected count (the probability-weighted sum of class indices),
        # which keeps gradients flowing back to the classifier
        counts = torch.arange(self.num_classes, device=self.device, dtype=torch.float32)
        expected = (log_probs.exp() * counts).sum(dim=1)
        loss = F.poisson_nll_loss(expected, y.float(), log_input=False)
        # training metrics
        acc = accuracy(preds, y)
        mae = mean_absolute_error(preds, y)
        mse = mean_squared_error(preds, y)
        self.log("train_loss", loss, on_step=True, on_epoch=True, logger=True)
        self.log("train_acc", acc, on_step=True, on_epoch=True, logger=True)
        self.log("train_mae", mae, on_step=True, on_epoch=True, logger=True)
        self.log("train_mse", mse, on_step=True, on_epoch=True, logger=True)
        return loss
    # logic for a single validation step
    def validation_step(self, batch, batch_idx):
        x, y = batch
        log_probs = self(x)
        preds = torch.argmax(log_probs, dim=1)
        # same differentiable Poisson NLL on the expected count as in training
        counts = torch.arange(self.num_classes, device=self.device, dtype=torch.float32)
        expected = (log_probs.exp() * counts).sum(dim=1)
        loss = F.poisson_nll_loss(expected, y.float(), log_input=False)
        # validation metrics
        acc = accuracy(preds, y)
        mae = mean_absolute_error(preds, y)
        mse = mean_squared_error(preds, y)
        self.log("val_loss", loss, on_step=True, on_epoch=True, logger=True)
        self.log("val_acc", acc, on_step=True, on_epoch=True, logger=True)
        self.log("val_mae", mae, on_step=True, on_epoch=True, logger=True)
        self.log("val_mse", mse, on_step=True, on_epoch=True, logger=True)
        return loss
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer
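# Sanity check (a sketch): the headless ResNet-18 global-pools to a 512-dim
# vector for any reasonable crop size, so the Linear head above always sees
# 512 features; random weights suffice for a pure shape check
_backbone = nn.Sequential(*list(models.resnet18().children())[:-1])
with torch.no_grad():
    assert _backbone(torch.rand(1, 3, 224, 224)).flatten(1).shape == (1, 512)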
# Test Dataset - ImageFolder expects class subfolders, which the test set lacks
class TestDataset(torch.utils.data.Dataset):
    def __init__(self, main_dir: str = ""):
        self.main_dir = main_dir
        self.transform = transforms.Compose(
            [
                transforms.Resize(size=256),
                transforms.CenterCrop(size=224),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )
        # natural sort so prediction order matches the submission's image order
        self.total_imgs = natsort.natsorted(glob(f"{main_dir}/*.png"))

    def __len__(self):
        return len(self.total_imgs)

    def __getitem__(self, idx):
        img_loc = self.total_imgs[idx]
        image = Image.open(img_loc).convert("RGB")
        tensor_image = self.transform(image)
        return tensor_image
if __name__ == "__main__":
    # Data Setup
    datamodule = DataModule(
        batch_size=1024, data_dir="../input/count-the-blue-boxes/train/train/"
    )
    datamodule.setup()
    # Callbacks
    early_stop_callback = EarlyStopping(
        monitor="val_loss", patience=3, verbose=False, mode="min"
    )
    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        filename="model-{epoch:02d}-{val_loss:.2f}",
        save_top_k=3,
        mode="min",
    )
    # Training
    # input shape matches the 224x224 CenterCrop produced by the transforms
    model = CountModel((3, 224, 224), NUM_CLASSES)
    trainer = pl.Trainer(
        max_epochs=1,
        progress_bar_refresh_rate=5,
        gpus=1,
        logger=wandb_logger,
        callbacks=[early_stop_callback, checkpoint_callback],
    )
    trainer.fit(model, datamodule)
    # Inference
    # load the best model - the checkpoint with the lowest validation loss,
    # recovered by parsing the val_loss embedded in each checkpoint filename
    model_ckpts = sorted(glob("**/checkpoints/*.ckpt", recursive=True))
    losses = []
    for model_ckpt in model_ckpts:
        loss = re.findall(r"\d+\.\d+", model_ckpt)
        losses.append(float(loss[0]))
    best_model = model_ckpts[int(np.argmin(losses))]
    print(f"Best Model: {best_model}")
    inference_model = CountModel.load_from_checkpoint(best_model)
    inference_model.eval()
    test_dataset = TestDataset("../input/count-the-blue-boxes/test/test/")
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset, batch_size=1024, num_workers=2
    )
    print(f"Test Dataset: {len(test_dataset)}\tTest DataLoader: {len(test_dataloader)}")
    y_pred = []
    with torch.no_grad():
        for imgs in test_dataloader:
            logits = inference_model(imgs)
            preds = torch.argmax(logits, dim=1)
            y_pred.extend(preds.numpy())
    # natural sort keeps the image names aligned with the prediction order
    all_imgs = natsort.natsorted(os.listdir(test_dataset.main_dir))
    submission = pd.DataFrame.from_dict({"images": all_imgs, "labels": y_pred})
    submission.to_csv("submission.csv", index=False)
    wandb.finish()