General Purpose Vision Pipeline
'''
General goals of this code:

1 - Create an adaptable, general pipeline that enables devs to understand what the
    convnet is seeing (through class activation maps like GradCAM)
2 - Ensure the pipeline is **configurable**

The pipeline leverages torch + torchvision + lightning + mlflow (logging)
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import lightning as L
from lightning.pytorch.loggers import MLFlowLogger
from lightning.pytorch.callbacks import EarlyStopping

import mlflow
import mlflow.pytorch

import torchvision
import torchvision.transforms as T

# Grad-CAM imports -- useful for visualizing CNNs and Transformers
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

# Cool progress bars
from tqdm import tqdm

import numpy as np
import random

# Image manipulations
from PIL import Image
# Set up a basic transformation
# We'll keep it simple, for now
train_transforms = T.Compose([
    T.RandomHorizontalFlip(),
    T.RandomVerticalFlip(),
    T.RandomAutocontrast(),
    T.ToTensor(),
])

transforms = T.Compose([
    T.ToTensor()
])
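# A common refinement (an assumption, not part of the original gist): normalize with
# approximate CIFAR-10 channel statistics, e.g.
#     T.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
# Note that visualize_with_cam below overlays the heatmap directly on the input
# tensor, so it would need the unnormalized image if normalization were added here.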
# For this example, we'll use torchvision's CIFAR-10 dataset
train_dataset = torchvision.datasets.CIFAR10(
    './data', train=True, download=True, transform=train_transforms)
train_loader = DataLoader(train_dataset, batch_size=16,
                          num_workers=7, shuffle=True)  # shuffle training batches each epoch

# Validation split of the same dataset, without augmentation
val_dataset = torchvision.datasets.CIFAR10(
    './data', train=False, download=True, transform=transforms)
val_loader = DataLoader(val_dataset, batch_size=16, num_workers=7)
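# Toward goal 2 (a configurable pipeline), one minimal sketch: lift the magic numbers
# into a single dict (names here are illustrative, not from the original) and thread
# it through the loaders and Trainer, e.g.
#     config = {"batch_size": 16, "num_workers": 7, "max_epochs": 1000}
#     train_loader = DataLoader(train_dataset, batch_size=config["batch_size"],
#                               num_workers=config["num_workers"], shuffle=True)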
# Define a model - this is a model submitted for a Berkeley CS 189 assignment, with a target of >75% val acc
class ConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv5 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.conv6 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.mp = nn.MaxPool2d(2, 2)
        self.flat = nn.Flatten()
        self.fc1 = nn.Linear(256 * 4 * 4, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 10)
    def forward(self, x):
        # First block, VGG inspired
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.bn1(self.mp(x))

        # Second block, VGG inspired
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = self.bn2(self.mp(x))

        # Third block, VGG inspired
        x = F.relu(self.conv5(x))
        x = F.relu(self.conv6(x))
        x = self.bn3(self.mp(x))

        x = self.flat(x)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
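# Illustrative sanity check (not in the original gist): CIFAR-10 inputs are 3x32x32,
# and the three MaxPool2d(2, 2) stages halve 32 -> 16 -> 8 -> 4, which is where the
# 256 * 4 * 4 input size of fc1 comes from.
with torch.no_grad():
    assert ConvNet()(torch.randn(2, 3, 32, 32)).shape == (2, 10)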
class LightningConv(L.LightningModule):
    def __init__(self, rand_num=100) -> None:
        super().__init__()
        self.net = ConvNet()
        self.rand_num = rand_num
        self.i = 0  # epoch counter, used as the mlflow logging step
        self.losses = []
        self.acc = 0
        self.n = 0

    def forward(self, x):
        # Delegate to the underlying ConvNet
        x = self.net(x)
        return x
    def configure_optimizers(self):
        optim = torch.optim.Adam(self.parameters())
        return optim

    def on_train_epoch_start(self):
        # Pick one random batch index this epoch for GradCAM visualization
        self.rand_idx = random.randint(0, self.rand_num)
        self.i += 1
        self.losses = []
        self.acc = 0
        self.n = 0

    def on_validation_epoch_start(self):
        self.losses = []
        self.acc = 0
        self.n = 0
    def on_train_epoch_end(self):
        mlflow.log_metric("train_total_loss", sum(self.losses), step=self.i)
        mlflow.log_metric("train_average_loss", sum(
            self.losses) / len(self.losses), step=self.i)
        mlflow.log_metric("train_acc", self.acc / self.n, step=self.i)

    def on_validation_epoch_end(self):
        mlflow.log_metric("val_total_loss", sum(self.losses), step=self.i)
        # Average per-batch loss, matching the train-side metric
        mlflow.log_metric("val_average_loss", sum(
            self.losses) / len(self.losses), step=self.i)
        mlflow.log_metric("val_acc", self.acc / self.n, step=self.i)
        # Periodically snapshot the model as an mlflow artifact
        if self.i % 100 == 0:
            mlflow.pytorch.log_model(self.net, f"model_epoch_{self.i}")
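    # Alternative (sketch, not in the original): Lightning's own self.log("val_acc", ...)
    # together with the imported-but-unused MLFlowLogger would route these metrics
    # through the Trainer, which is what callbacks like EarlyStopping monitor.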
    def training_step(self, batch, batch_idx):
        X, y = batch
        pred = self.forward(X)
        loss = F.cross_entropy(pred, y)

        # On one randomly chosen batch per epoch, log a GradCAM overlay for a
        # random image so we can inspect what the network attends to
        rand_img = random.randint(0, X.shape[0] - 1)
        if self.rand_idx == batch_idx:
            vis_image = visualize_with_cam(X[rand_img], y[rand_img], self.net)
            mlflow.log_image(
                vis_image, f'train_epoch_{self.i}_{batch_idx}_{rand_img}.jpg')

        self.acc += (torch.argmax(pred, dim=1) == y).sum().item()
        self.losses.append(loss.item())
        self.n += X.shape[0]
        return loss

    def validation_step(self, batch, batch_idx):
        X, y = batch
        pred = self.forward(X)
        loss = F.cross_entropy(pred, y)
        self.acc += (torch.argmax(pred, dim=1) == y).sum().item()
        self.losses.append(loss.item())
        self.n += X.shape[0]
        return loss
def visualize_with_cam(X, y, net: torch.nn.Module):
    # Build a GradCAM heatmap for the true class, targeting the last conv layer
    target_layers = [net.conv6]
    label_int: int = y.item()
    explainability_targets = [ClassifierOutputTarget(label_int)]
    with GradCAM(model=net, target_layers=target_layers) as cam:
        grayscale_cam = cam(input_tensor=X.unsqueeze(0),
                            targets=explainability_targets)[0, :]
        img: Image.Image = torchvision.transforms.ToPILImage()(X)
        numpy_image = np.asarray(img) / 255.0
        vis_image = show_cam_on_image(
            numpy_image, grayscale_cam, use_rgb=True)
        return Image.fromarray(vis_image).resize((500, 500))
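# Example usage outside the training loop (hypothetical; with untrained weights the
# heatmap is meaningless, but the call shape is the same):
#     X0, y0 = val_dataset[0]
#     visualize_with_cam(X0, torch.tensor(y0), ConvNet()).show()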
# View the training run live by typing 'mlflow ui'
experiment = mlflow.set_experiment("cifar_gradcam")
with mlflow.start_run(run_name="lightning_run", experiment_id=experiment.experiment_id):
    conv_lightning_model = LightningConv(rand_num=100)

    # EarlyStopping is left disabled; enabling it would require logging the monitored
    # metric via self.log, and accuracy should be monitored with mode="max", e.g.
    # [EarlyStopping(monitor="val_accuracy", mode="max", patience=3)]
    callbacks = []
    trainer = L.Trainer(accelerator="cuda",
                        max_epochs=1000, callbacks=callbacks)

    # Fit the lightning module to the dataset
    trainer.fit(conv_lightning_model, train_loader, val_loader)
    trainer.save_checkpoint('final', weights_only=True)
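# Reloading the saved weights later is a one-liner (sketch; assumes the 'final'
# checkpoint written above):
#     model = LightningConv.load_from_checkpoint('final')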