Skip to content

Instantly share code, notes, and snippets.

@AnshKetchum
Last active February 5, 2024 06:02
Show Gist options
  • Save AnshKetchum/dc0bcca61ff6fd2a6e05839e2c892926 to your computer and use it in GitHub Desktop.
Save AnshKetchum/dc0bcca61ff6fd2a6e05839e2c892926 to your computer and use it in GitHub Desktop.
General Purpose Vision Pipeline
'''
General goals of this code:
1 - Create an adaptable, general-purpose pipeline that helps developers understand what the
convnet is seeing (through class activation maps such as Grad-CAM)
2 - Ensure the pipeline is **configurable**
The pipeline uses torch + torchvision + lightning + mlflow (for logging)
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import lightning as L
from lightning.pytorch.loggers import MLFlowLogger
from lightning.pytorch.callbacks import EarlyStopping
import mlflow
import mlflow.pytorch
import torchvision
import torchvision.transforms as T
# Grad-CAM imports -- useful for visualizing CNNs and Transformers
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
# Cool progress bars
from tqdm import tqdm
import numpy as np
import random
# Image manipulations
from PIL import Image
# Setup a basic transformation
# Training-time augmentation; we'll keep it simple, for now.
# NOTE(review): RandomVerticalFlip produces upside-down CIFAR-10 objects, which
# may hurt accuracy -- consider dropping it.
train_transforms = T.Compose([
    T.RandomHorizontalFlip(),
    T.RandomVerticalFlip(),
    T.RandomAutocontrast(),
    T.ToTensor(),
])
# Evaluation transform: no augmentation, only conversion to a tensor.
transforms = T.Compose([
    T.ToTensor()
])

# Shared loader settings for both splits.
BATCH_SIZE = 16
NUM_WORKERS = 7

# For this example, we'll use torchvision's CIFAR-10 training split.
train_dataset = torchvision.datasets.CIFAR10(
    './data', train=True, download=True, transform=train_transforms)
# shuffle=True: without it every epoch visits the training set in the same
# fixed order, which degrades SGD-style training.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                          shuffle=True, num_workers=NUM_WORKERS)

# The CIFAR-10 test split (train=False) serves as the validation set.
val_dataset = torchvision.datasets.CIFAR10(
    './data', train=False, download=True, transform=transforms)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE,
                        num_workers=NUM_WORKERS)
# Define a model - this is a model submitted for a Berkeley CS 189 assignment, with a target of >75% val acc
class ConvNet(nn.Module):
    """VGG-inspired CNN classifier for 3-channel 32x32 images such as CIFAR-10.

    Three conv blocks (two 3x3 convs with ReLU, then a 2x2 max-pool and
    BatchNorm) shrink the spatial size 32 -> 16 -> 8 -> 4 while widening to
    256 channels; a three-layer fully connected head produces the logits.

    Args:
        num_classes: number of output logits. Defaults to 10 (CIFAR-10).
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        # Attribute names/order are kept stable: they define the state_dict
        # keys, and visualize_with_cam looks up `net.conv6` by name.
        # Block 1: 3 -> 64 channels.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        # Block 2: 64 -> 128 channels.
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        # Block 3: 128 -> 256 channels.
        self.conv5 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.conv6 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.mp = nn.MaxPool2d(2, 2)
        self.flat = nn.Flatten()
        # 256 channels x 4 x 4: a 32x32 input after three 2x2 max-pools.
        self.fc1 = nn.Linear(256 * 4 * 4, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return unnormalized logits of shape (batch, num_classes).

        Expects (batch, 3, 32, 32) input; fc1's in_features is sized for it.
        """
        # First block, VGG inspired
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.bn1(self.mp(x))
        # Second block, VGG inspired
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = self.bn2(self.mp(x))
        # Third block, VGG inspired
        x = F.relu(self.conv5(x))
        x = F.relu(self.conv6(x))
        x = self.bn3(self.mp(x))
        # Classifier head.
        x = self.flat(x)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
class LightningConv(L.LightningModule):
    """LightningModule wrapping ConvNet that logs epoch metrics and Grad-CAM images to MLflow.

    Metrics are accumulated by hand in the step methods and pushed with
    ``mlflow.log_metric`` using the training-epoch counter ``self.i`` as step.

    Args:
        rand_num: inclusive upper bound for the training batch index at which
            one Grad-CAM visualization is logged each epoch. If it exceeds the
            number of batches in an epoch, that epoch logs no image.
    """

    def __init__(self, rand_num=100) -> None:
        super().__init__()
        self.net = ConvNet()
        self.rand_num = rand_num
        self.rand_idx = 0
        # Training-epoch counter; incremented in on_train_epoch_start and used
        # as the MLflow step for both train and val metrics.
        self.i = 0
        # Separate train/val accumulators. Lightning runs the validation loop
        # at the end of each training epoch *before* on_train_epoch_end, so a
        # shared set would be reset by the validation hooks and the "train_*"
        # metrics would actually report validation numbers.
        self.train_losses = []
        self.train_correct = 0
        self.train_seen = 0
        self.val_losses = []
        self.val_correct = 0
        self.val_seen = 0

    def forward(self, x):
        """Return the wrapped ConvNet's logits for ``x``."""
        return self.net(x)

    def configure_optimizers(self):
        """Optimize all parameters with Adam using its default hyperparameters."""
        return torch.optim.Adam(self.parameters())

    def _log_epoch_metrics(self, prefix, losses, correct, seen):
        """Log ``{prefix}_total_loss``, ``{prefix}_average_loss`` and ``{prefix}_acc``."""
        if not losses or seen == 0:
            # Nothing was accumulated; avoid a ZeroDivisionError.
            return
        mlflow.log_metric(f"{prefix}_total_loss", sum(losses), step=self.i)
        # Each entry is a per-batch mean loss, so average over batches
        # (len(losses)), not over samples.
        mlflow.log_metric(f"{prefix}_average_loss",
                          sum(losses) / len(losses), step=self.i)
        mlflow.log_metric(f"{prefix}_acc", correct / seen, step=self.i)

    def on_train_epoch_start(self):
        # random.randint is inclusive on both ends: rand_idx is in [0, rand_num].
        self.rand_idx = random.randint(0, self.rand_num)
        self.i += 1
        self.train_losses = []
        self.train_correct = 0
        self.train_seen = 0

    def on_validation_epoch_start(self):
        self.val_losses = []
        self.val_correct = 0
        self.val_seen = 0

    def on_train_epoch_end(self):
        self._log_epoch_metrics("train", self.train_losses,
                                self.train_correct, self.train_seen)

    def on_validation_epoch_end(self):
        self._log_epoch_metrics("val", self.val_losses,
                                self.val_correct, self.val_seen)
        # self.i is still 0 before the first training epoch starts (e.g. during
        # Lightning's sanity-check validation); don't save the untrained model.
        if self.i > 0 and self.i % 100 == 0:
            mlflow.pytorch.log_model(self.net, f"model_epoch_{self.i}")

    def training_step(self, batch, batch_idx):
        """Compute the cross-entropy loss, accumulate train stats, and maybe log a Grad-CAM image."""
        X, y = batch
        pred = self.forward(X)
        loss = F.cross_entropy(pred, y)
        if batch_idx == self.rand_idx:
            # Pick one sample from this batch to visualize.
            rand_img = random.randint(0, X.shape[0] - 1)
            vis_image = visualize_with_cam(X[rand_img], y[rand_img], self.net)
            mlflow.log_image(
                vis_image, f'train_epoch_{self.i}_{batch_idx}_{rand_img}.jpg')
        self.train_correct += (torch.argmax(pred, dim=1) == y).sum().item()
        self.train_losses.append(loss.item())
        self.train_seen += X.shape[0]
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the cross-entropy loss and accumulate validation stats."""
        X, y = batch
        pred = self.forward(X)
        loss = F.cross_entropy(pred, y)
        self.val_correct += (torch.argmax(pred, dim=1) == y).sum().item()
        self.val_losses.append(loss.item())
        self.val_seen += X.shape[0]
        return loss
def visualize_with_cam(X, y, net: torch.nn.Module, target_layers=None, size=(500, 500)):
    """Overlay a Grad-CAM heatmap for class ``y`` on image ``X``.

    Args:
        X: a single unbatched image tensor (a batch dim is added here);
            assumed (C, H, W) float in [0, 1] as produced by ``T.ToTensor`` --
            TODO confirm for other pipelines.
        y: 0-dim tensor holding the class index to explain.
        net: classifier to explain. Its train/eval mode is restored on return.
        target_layers: layers Grad-CAM hooks into; defaults to
            ``[net.conv6]`` (the last conv layer of ConvNet).
        size: (width, height) of the returned image.

    Returns:
        A PIL image of the heatmap blended onto the input.
    """
    if target_layers is None:
        target_layers = [net.conv6]
    label_int: int = y.item()
    explainability_targets = [ClassifierOutputTarget(label_int)]
    # pytorch_grad_cam's GradCAM calls model.eval() when constructed. When this
    # runs mid-training, that would leave BatchNorm frozen for the rest of the
    # epoch, so restore the caller's mode afterwards.
    was_training = net.training
    try:
        with GradCAM(model=net, target_layers=target_layers) as cam:
            grayscale_cam = cam(input_tensor=X.unsqueeze(0),
                                targets=explainability_targets)[0, :]
    finally:
        net.train(was_training)
    img = torchvision.transforms.ToPILImage()(X)
    # show_cam_on_image expects the RGB image scaled to [0, 1].
    numpy_image = np.asarray(img) / 255.0
    vis_image = show_cam_on_image(
        numpy_image, grayscale_cam, use_rgb=True)
    return Image.fromarray(vis_image).resize(size)
# View the training run live by typing 'mlflow ui'
# The __main__ guard matters: DataLoader workers (num_workers > 0) started with
# the "spawn" method (default on Windows/macOS) re-import this module, and
# without the guard each worker would start training again.
if __name__ == "__main__":
    experiment = mlflow.set_experiment("cifar_gradcam")
    with mlflow.start_run(run_name="lightning_run", experiment_id=experiment.experiment_id):
        conv_lightning_model = LightningConv(rand_num=100)
        # NOTE: EarlyStopping monitors metrics registered via self.log(...);
        # LightningConv logs to MLflow directly, so none is enabled. If one is
        # added, monitor accuracy with mode="max" (not "min").
        callbacks = []
        # "auto" uses a GPU when available and falls back to CPU otherwise,
        # instead of failing outright on machines without CUDA.
        trainer = L.Trainer(accelerator="auto",
                            max_epochs=1000, callbacks=callbacks)
        # Fit the lightning module to the dataset
        trainer.fit(conv_lightning_model, train_loader, val_loader)
        trainer.save_checkpoint('final', weights_only=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment