@theja-vanka
Created May 18, 2025 20:12
lightning framework
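
This gist bundles what appear to be three files: the data pipeline, the model, and the training entry point. Judging from the imports in the training script (from src.data import TFFCDataModule, from src.models import LightningModel), the first two live at src/data.py and src/models.py; the section headers below follow that inference.

# ---------------------------------------------------------------------------
# src/data.py -- datasets, collate functions, and the LightningDataModule
# ---------------------------------------------------------------------------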
import ast
import cv2
import torch
import pandas as pd
import lightning as L
import multiprocessing
from pathlib import Path
import albumentations as A
from torch.utils.data import Dataset, DataLoader
from albumentations.pytorch import ToTensorV2

class CSVDataSet(Dataset):
    """Training/validation dataset backed by a CSV of image paths and labels."""

    def __init__(self, df, root_path, file_column, label_column, transforms=None):
        super().__init__()
        self.df = pd.read_csv(df)  # `df` is the path to a CSV file, not a DataFrame
        self.transforms = transforms
        self.root_path = Path(root_path) if root_path else None
        self.file_column = file_column
        self.label_column = label_column

    def __getitem__(self, i):
        if self.root_path:
            image_path = self.root_path / self.df.iloc[i][self.file_column]
        else:
            image_path = self.df.iloc[i][self.file_column]
        sample = cv2.imread(str(image_path))  # imread expects str, not Path
        sample = cv2.cvtColor(sample, cv2.COLOR_BGR2RGB)
        # literal_eval parses the stored label literal without eval's
        # arbitrary-code-execution risk
        label = ast.literal_eval(self.df.iloc[i][self.label_column])
        return sample, label

    def __len__(self):
        return len(self.df)

    def _collate_fn(self, batch):
        imgs, classes = list(zip(*batch))
        if self.transforms:
            imgs = [self.transforms(image=img)["image"][None] for img in imgs]
        classes = [torch.tensor([clss]) for clss in classes]
        imgs, classes = [torch.cat(i) for i in [imgs, classes]]
        return imgs, classes
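
# Note: labels are parsed with ast.literal_eval, so the label column is
# presumably stored as a Python literal per row, e.g. a multi-hot list like
# "[0, 1, 0]", consistent with the BCEWithLogitsLoss / MultilabelAccuracy
# setup in the model file.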

class CSVInferenceDataSet(Dataset):
    """Like CSVDataSet, but also returns the image path for bookkeeping."""

    def __init__(self, df, root_path, file_column, label_column, transforms=None):
        super().__init__()
        self.df = pd.read_csv(df)
        self.transforms = transforms
        self.root_path = Path(root_path) if root_path else None
        self.file_column = file_column
        self.label_column = label_column

    def __getitem__(self, i):
        if self.root_path:
            image_path = self.root_path / self.df.iloc[i][self.file_column]
        else:
            image_path = self.df.iloc[i][self.file_column]
        sample = cv2.imread(str(image_path))
        sample = cv2.cvtColor(sample, cv2.COLOR_BGR2RGB)
        label = ast.literal_eval(self.df.iloc[i][self.label_column])
        return str(image_path), sample, label

    def __len__(self):
        return len(self.df)

    def _collate_fn(self, batch):
        filename, imgs, classes = list(zip(*batch))
        if self.transforms:
            imgs = [self.transforms(image=img)["image"][None] for img in imgs]
        classes = [torch.tensor([clss]) for clss in classes]
        imgs, classes = [torch.cat(i) for i in [imgs, classes]]
        filename = [str(f) for f in filename]
        return filename, imgs, classes

def generate_pseudo_set(df, root_path, file_column, label_column, batch_size):
    transform = A.Compose(
        [
            A.Resize(224, 224),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ]
    )
    test_dataset = CSVInferenceDataSet(
        df, root_path, file_column, label_column, transform
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=test_dataset._collate_fn,
        num_workers=multiprocessing.cpu_count() - 1,
        pin_memory=True,
    )
    return test_loader

class CSVMineDataSet(Dataset):
    """Unlabeled mining dataset; expects an `image_key` column of image paths."""

    def __init__(self, df, transforms=None, fast_dev_run=True):
        super().__init__()
        self.df = pd.read_csv(df)
        if fast_dev_run:
            self.df = self.df.head(10000)  # cap the set for quick dev runs
        self.transforms = transforms

    def __getitem__(self, i):
        image_path = self.df.iloc[i].image_key
        sample = cv2.imread(str(image_path))
        sample = cv2.cvtColor(sample, cv2.COLOR_BGR2RGB)
        return str(image_path), sample

    def __len__(self):
        return len(self.df)

    def _collate_fn(self, batch):
        filename, imgs = list(zip(*batch))
        if self.transforms:
            imgs = [self.transforms(image=img)["image"][None] for img in imgs]
        imgs = torch.cat(imgs)
        filename = [str(f) for f in filename]
        return filename, imgs

def generate_mine_set(filename, batch_size, fast_dev_run=1):
    transform = A.Compose(
        [
            A.Resize(224, 224),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ]
    )
    fast_dev_run = fast_dev_run == 1
    mine_dataset = CSVMineDataSet(filename, transform, fast_dev_run)
    mine_loader = DataLoader(
        mine_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=mine_dataset._collate_fn,
        num_workers=multiprocessing.cpu_count() - 1,
        pin_memory=True,
    )
    return mine_loader
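
# Usage sketch for the loader factories above (file/column names hypothetical):
#
#     pseudo_loader = generate_pseudo_set(
#         "pseudo.csv", "images/", "image_key", "labels", batch_size=64
#     )
#     mine_loader = generate_mine_set("mine_candidates.csv", batch_size=64)
#     for filenames, imgs in mine_loader:
#         ...  # run the model, keep confident predictions as pseudo-labels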

class TFFCDataModule(L.LightningDataModule):
    def __init__(self, config, batch_size=None):
        super().__init__()
        self.config = config
        if batch_size is None:
            self.batch_size = self.config["batch_size"]
        else:
            self.batch_size = batch_size

    def prepare_data(self):
        pass

    def setup(self, stage: str):
        self.transform = A.Compose(
            [
                A.Resize(224, 224),
                A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                ToTensorV2(),
            ]
        )
        self.train_transform = A.Compose(
            [
                A.Resize(224, 224),
                A.HorizontalFlip(p=0.5),
                A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                ToTensorV2(),
            ]
        )
        self.train_dataset = CSVDataSet(
            self.config["train_df"],
            self.config["root_path"],
            self.config["file_column"],
            self.config["label_column"],
            self.train_transform,
        )
        self.val_dataset = CSVDataSet(
            self.config["val_df"],
            self.config["root_path"],
            self.config["file_column"],
            self.config["label_column"],
            self.transform,
        )
        self.test_dataset = CSVDataSet(
            self.config["test_df"],
            self.config["root_path"],
            self.config["file_column"],
            self.config["label_column"],
            self.transform,
        )

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            collate_fn=self.train_dataset._collate_fn,
            num_workers=multiprocessing.cpu_count() - 1,
            pin_memory=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            collate_fn=self.val_dataset._collate_fn,
            num_workers=multiprocessing.cpu_count() - 1,
            pin_memory=True,
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            collate_fn=self.test_dataset._collate_fn,
            num_workers=multiprocessing.cpu_count() - 1,
            pin_memory=True,
        )


if __name__ == "__main__":
    pass
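
# A minimal usage sketch (values hypothetical; the keys mirror the
# self.config[...] lookups above):
#
#     config = {
#         "batch_size": 32,
#         "train_df": "data/train.csv",
#         "val_df": "data/val.csv",
#         "test_df": "data/test.csv",
#         "root_path": "images/",
#         "file_column": "image_key",
#         "label_column": "labels",
#     }
#     dm = TFFCDataModule(config)
#     dm.setup("fit")
#     imgs, targets = next(iter(dm.train_dataloader()))

# ---------------------------------------------------------------------------
# src/models.py -- model definition and LightningModule
# ---------------------------------------------------------------------------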
import torch
import torch.nn as nn
import lightning.pytorch as L
from datetime import datetime
from torchvision import models
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchmetrics.classification import MultilabelAccuracy

class Config:
    CRITERION = nn.BCEWithLogitsLoss()
    OPTIMIZER = torch.optim.AdamW
    SCHEDULER = CosineAnnealingLR

class ObstructionDetectionModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.model = models.mobilenet_v2(weights="IMAGENET1K_V2")
        # swap the final classifier layer for one matching the label count
        self.model.classifier[1] = nn.Linear(
            self.model.classifier[1].in_features, self.config["classes"]
        )

    def forward(self, inputs):
        outputs = self.model(inputs)
        return outputs
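
# Shape check (sketch): with config = {"classes": 3}, a (1, 3, 224, 224) input
# yields a (1, 3) logits tensor, one logit per label:
#
#     m = ObstructionDetectionModel({"classes": 3})
#     m(torch.randn(1, 3, 224, 224)).shape  # torch.Size([1, 3])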

class LightningModel(L.LightningModule):
    def __init__(self, config, learning_rate=0.00001):
        super().__init__()
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        self.config = config
        self.criterion = Config.CRITERION
        self.optimizer = Config.OPTIMIZER
        self.scheduler = Config.SCHEDULER
        self.model = ObstructionDetectionModel(config)
        self.model = torch.compile(self.model)
        self.train_acc = MultilabelAccuracy(num_labels=config["classes"])
        self.val_acc = MultilabelAccuracy(num_labels=config["classes"])
        self.test_acc = MultilabelAccuracy(num_labels=config["classes"])

    def forward(self, x):
        return self.model(x)

    def _shared_step(self, batch):
        sample, target = batch
        logits = self(sample)
        # BCEWithLogitsLoss expects float targets
        loss = self.criterion(logits, target.float())
        return loss, target, logits
    # These hooks log the date/time as metrics so the checkpoint filename
    # template in train.py can embed a timestamp.
    def on_train_epoch_end(self):
        _train_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        train_ts_date = int(_train_timestamp.split("_")[0])
        train_ts_time = int(_train_timestamp.split("_")[1])
        self.log("train_ts_date", train_ts_date)
        self.log("train_ts_time", train_ts_time)

    def on_validation_epoch_end(self):
        _val_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        val_ts_date = int(_val_timestamp.split("_")[0])
        val_ts_time = int(_val_timestamp.split("_")[1])
        self.log("val_ts_date", val_ts_date)
        self.log("val_ts_time", val_ts_time)

    def training_step(self, batch, batch_idx):
        loss, true_labels, logits = self._shared_step(batch)
        self.log("train_loss", loss, prog_bar=True, on_epoch=True, on_step=False)
        self.train_acc(logits, true_labels)
        self.log(
            "train_acc", self.train_acc, prog_bar=True, on_epoch=True, on_step=False
        )
        return loss

    def validation_step(self, batch, batch_idx):
        loss, true_labels, logits = self._shared_step(batch)
        self.log("val_loss", loss, prog_bar=True)
        self.val_acc(logits, true_labels)
        self.log("val_acc", self.val_acc, prog_bar=True)

    def test_step(self, batch, batch_idx):
        _, true_labels, logits = self._shared_step(batch)
        self.test_acc(logits, true_labels)
        self.log("test_acc", self.test_acc)

    def configure_optimizers(self):
        optimizer = self.optimizer(self.parameters(), lr=self.learning_rate)
        scheduler = self.scheduler(optimizer, T_max=5)
        return [optimizer], [scheduler]

if __name__ == "__main__":
    pass
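
# ---------------------------------------------------------------------------
# train.py -- LightningCLI entry point with MLflow tracking
# ---------------------------------------------------------------------------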
import os
import time
import torch
import mlflow
from watermark import watermark
from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.callbacks import (
    RichProgressBar,
    ModelCheckpoint,
    LearningRateFinder,
    LearningRateMonitor,
)
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from src.data import TFFCDataModule
from src.models import LightningModel

os.environ["LOGNAME"] = "krishnatheja.vanka"
torch.set_float32_matmul_precision("high")

class Config:
    MODEL_NAME = "class3"


def cli_main():
    # timestamped artifact directory per run
    version_no = time.strftime("%Y%m%d_%H%M%S", time.gmtime())
    model_name = Config.MODEL_NAME
    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        dirpath=f"artifacts/iter_{version_no}_{model_name}/",
        filename="RC_{val_ts_date:.12g}_{val_ts_time:g}_ep{epoch:03d}",
        mode="min",
        save_top_k=1,
        auto_insert_metric_name=False,
        save_on_train_epoch_end=False,
    )
    mlflow.set_experiment(experiment_name="tffc-road-condition-multi-label")
    mlflow.start_run(run_name=f"{version_no}")
    mlflow.pytorch.autolog()
    cli = LightningCLI(
        LightningModel,
        TFFCDataModule,
        seed_everything_default=42,
        run=False,  # parse args/config only; fit/test are invoked manually below
        trainer_defaults={
            "fast_dev_run": False,
            "callbacks": [
                checkpoint_callback,
                RichProgressBar(),
                LearningRateFinder(),
                LearningRateMonitor(logging_interval="epoch"),
                EarlyStopping(monitor="train_acc", mode="max"),
            ],
        },
    )
    cli.trainer.fit(cli.model, datamodule=cli.datamodule)
    cli.trainer.test(ckpt_path="best", datamodule=cli.datamodule)
    mlflow.end_run()


if __name__ == "__main__":
    # python train.py --config conf/masterconfig.yaml
    print(watermark(packages="torch,lightning", python=True))
    print("Torch CUDA available?", torch.cuda.is_available())
    cli_main()
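
# A conf/masterconfig.yaml for this CLI might look roughly like the sketch
# below (values hypothetical; model.config / data.config mirror the dict
# lookups in the two modules, and the live schema can be inspected with
# LightningCLI's --print_config):
#
#     model:
#       config:
#         classes: 3
#       learning_rate: 1.0e-05
#     data:
#       config:
#         batch_size: 32
#         train_df: data/train.csv
#         val_df: data/val.csv
#         test_df: data/test.csv
#         root_path: /data/images
#         file_column: image_key
#         label_column: labels
#     trainer:
#       max_epochs: 20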