Created May 18, 2025 20:12
Lightning framework: a multilabel image-classification training setup built from a CSV-backed data module (src/data.py), a MobileNetV2 LightningModule (src/models.py), and a LightningCLI training script (train.py).
# src/data.py
import ast
import cv2
import torch
import pandas as pd
import lightning as L
import multiprocessing
from pathlib import Path
import albumentations as A
from torch.utils.data import Dataset, DataLoader
from albumentations.pytorch import ToTensorV2
class CSVDataSet(Dataset):
    """Training/validation dataset backed by a CSV of image paths and multi-hot labels."""

    def __init__(self, df, root_path, file_column, label_column, transforms=None):
        super().__init__()
        self.df = pd.read_csv(df)
        self.transforms = transforms
        self.root_path = Path(root_path) if root_path else None
        self.file_column = file_column
        self.label_column = label_column

    def __getitem__(self, i):
        if self.root_path:
            image_path = self.root_path / self.df.iloc[i][self.file_column]
        else:
            image_path = self.df.iloc[i][self.file_column]
        sample = cv2.imread(str(image_path))  # cv2.imread expects a string path
        sample = cv2.cvtColor(sample, cv2.COLOR_BGR2RGB)
        # Labels are stored as stringified lists, e.g. "[0, 1, 0]";
        # ast.literal_eval is a safer drop-in for eval here.
        label = ast.literal_eval(self.df.iloc[i][self.label_column])
        return sample, label

    def __len__(self):
        return len(self.df)

    def _collate_fn(self, batch):
        imgs, classes = list(zip(*batch))
        if self.transforms:
            imgs = [self.transforms(image=img)["image"][None] for img in imgs]
        classes = [torch.tensor([clss]) for clss in classes]
        imgs, classes = [torch.cat(i) for i in [imgs, classes]]
        return imgs, classes
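# Illustrative helper, not part of the original gist: a minimal sketch of how
# CSVDataSet is wired to a DataLoader via its own `_collate_fn`. The CSV path,
# image root, and column names are assumptions; the label column is expected to
# hold stringified multi-hot lists such as "[0, 1, 0]".
def _example_train_loader(batch_size: int = 8) -> DataLoader:
    transform = A.Compose(
        [
            A.Resize(224, 224),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ]
    )
    dataset = CSVDataSet(
        df="data/train.csv",      # assumed CSV path
        root_path="data/images",  # assumed image root
        file_column="filepath",   # assumed column holding relative image paths
        label_column="labels",    # assumed column holding stringified label lists
        transforms=transform,
    )
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=dataset._collate_fn,
    )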
class CSVInferenceDataSet(Dataset):
    """Like CSVDataSet, but also returns the source file path with each sample."""

    def __init__(self, df, root_path, file_column, label_column, transforms=None):
        super().__init__()
        self.df = pd.read_csv(df)
        self.transforms = transforms
        self.root_path = Path(root_path) if root_path else None
        self.file_column = file_column
        self.label_column = label_column

    def __getitem__(self, i):
        if self.root_path:
            image_path = self.root_path / self.df.iloc[i][self.file_column]
        else:
            image_path = self.df.iloc[i][self.file_column]
        sample = cv2.imread(str(image_path))
        sample = cv2.cvtColor(sample, cv2.COLOR_BGR2RGB)
        label = ast.literal_eval(self.df.iloc[i][self.label_column])
        return str(image_path), sample, label

    def __len__(self):
        return len(self.df)

    def _collate_fn(self, batch):
        filename, imgs, classes = list(zip(*batch))
        if self.transforms:
            imgs = [self.transforms(image=img)["image"][None] for img in imgs]
        classes = [torch.tensor([clss]) for clss in classes]
        imgs, classes = [torch.cat(i) for i in [imgs, classes]]
        filename = [str(f) for f in filename]
        return filename, imgs, classes
def generate_pseudo_set(df, root_path, file_column, label_column, batch_size):
    """Build a non-shuffled loader that also yields file names, e.g. for pseudo-labelling."""
    transform = A.Compose(
        [
            A.Resize(224, 224),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ]
    )
    test_dataset = CSVInferenceDataSet(df, root_path, file_column, label_column, transform)
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=test_dataset._collate_fn,
        num_workers=multiprocessing.cpu_count() - 1,
        pin_memory=True,
    )
    return test_loader
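# Illustrative usage (not part of the original gist; the CSV path, column names,
# and batch size are assumptions): the inference loader additionally yields the
# source file names, which is what makes it useful for writing pseudo-labels
# back out against the original images.
#
#     loader = generate_pseudo_set(
#         "data/unlabeled.csv", "data/images", "filepath", "labels", batch_size=32
#     )
#     for filenames, images, labels in loader:
#         ...  # run the model on `images`, keep predictions keyed by `filenames`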
class CSVMineDataSet(Dataset):
    """Unlabelled dataset for mining: reads image paths from an `image_key` column."""

    def __init__(self, df, transforms=None, fast_dev_run=True):
        super().__init__()
        self.df = pd.read_csv(df)
        if fast_dev_run:
            # Keep only the first 10k rows for quick smoke runs
            self.df = self.df.head(10000)
        self.transforms = transforms

    def __getitem__(self, i):
        image_path = self.df.iloc[i].image_key
        sample = cv2.imread(str(image_path))
        sample = cv2.cvtColor(sample, cv2.COLOR_BGR2RGB)
        return str(image_path), sample

    def __len__(self):
        return len(self.df)

    def _collate_fn(self, batch):
        filename, imgs = list(zip(*batch))
        if self.transforms:
            imgs = [self.transforms(image=img)["image"][None] for img in imgs]
        imgs = torch.cat(imgs)
        filename = [str(f) for f in filename]
        return filename, imgs


def generate_mine_set(filename, batch_size, fast_dev_run=1):
    """Build a non-shuffled loader over an unlabelled CSV; fast_dev_run=1 caps it at 10k rows."""
    transform = A.Compose(
        [
            A.Resize(224, 224),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ]
    )
    fast_dev_run = fast_dev_run == 1
    mine_dataset = CSVMineDataSet(filename, transform, fast_dev_run)
    mine_loader = DataLoader(
        mine_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=mine_dataset._collate_fn,
        num_workers=multiprocessing.cpu_count() - 1,
        pin_memory=True,
    )
    return mine_loader
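# Illustrative usage (assumed CSV path; the CSV must contain an `image_key`
# column with image paths). Passing fast_dev_run=0 disables the 10k-row cap:
#
#     mine_loader = generate_mine_set("data/mine_pool.csv", batch_size=64, fast_dev_run=0)
#     for filenames, images in mine_loader:
#         ...  # unlabelled batches: file names plus image tensors only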
class TFFCDataModule(L.LightningDataModule):
    """LightningDataModule that builds train/val/test loaders from CSV files."""

    def __init__(self, config, batch_size=None):
        super().__init__()
        self.config = config
        self.batch_size = self.config["batch_size"] if batch_size is None else batch_size

    def prepare_data(self):
        pass

    def setup(self, stage: str):
        self.transform = A.Compose(
            [
                A.Resize(224, 224),
                A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                ToTensorV2(),
            ]
        )
        self.train_transform = A.Compose(
            [
                A.Resize(224, 224),
                A.HorizontalFlip(p=0.5),
                A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                ToTensorV2(),
            ]
        )
        self.train_dataset = CSVDataSet(
            self.config["train_df"],
            self.config["root_path"],
            self.config["file_column"],
            self.config["label_column"],
            self.train_transform,
        )
        self.val_dataset = CSVDataSet(
            self.config["val_df"],
            self.config["root_path"],
            self.config["file_column"],
            self.config["label_column"],
            self.transform,
        )
        self.test_dataset = CSVDataSet(
            self.config["test_df"],
            self.config["root_path"],
            self.config["file_column"],
            self.config["label_column"],
            self.transform,
        )

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            collate_fn=self.train_dataset._collate_fn,
            num_workers=multiprocessing.cpu_count() - 1,
            pin_memory=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            collate_fn=self.val_dataset._collate_fn,
            num_workers=multiprocessing.cpu_count() - 1,
            pin_memory=True,
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            collate_fn=self.test_dataset._collate_fn,
            num_workers=multiprocessing.cpu_count() - 1,
            pin_memory=True,
        )


if __name__ == "__main__":
    pass
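For reference, a minimal sketch of the configuration dictionary TFFCDataModule expects. The keys follow the `self.config[...]` lookups above; the concrete paths, column names, and batch size are assumptions:

from src.data import TFFCDataModule

config = {
    "batch_size": 32,            # assumed value
    "root_path": "data/images",  # assumed image root
    "train_df": "data/train.csv",
    "val_df": "data/val.csv",
    "test_df": "data/test.csv",
    "file_column": "filepath",   # assumed column name
    "label_column": "labels",    # assumed column name
}
datamodule = TFFCDataModule(config)
datamodule.setup(stage="fit")
train_loader = datamodule.train_dataloader()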
# src/models.py
import torch
import torch.nn as nn
import lightning.pytorch as L
from datetime import datetime
from torchvision import models
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchmetrics.classification import MultilabelAccuracy


class Config:
    # Loss, optimizer, and LR-scheduler choices used by the LightningModule below
    CRITERION = nn.BCEWithLogitsLoss()
    OPTIMIZER = torch.optim.AdamW
    SCHEDULER = CosineAnnealingLR
class ObstructionDetectionModel(nn.Module):
    """MobileNetV2 backbone with the classifier head resized to config["classes"] outputs."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.model = models.mobilenet_v2(weights="IMAGENET1K_V2")
        self.model.classifier[1] = nn.Linear(
            self.model.classifier[1].in_features, self.config["classes"]
        )

    def forward(self, inputs):
        return self.model(inputs)
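# Illustrative shape check (not part of the original gist): with the 3-class
# config assumed here and the 224x224 inputs produced by the data module, the
# resized head emits one logit per label.
#
#     model = ObstructionDetectionModel({"classes": 3})
#     logits = model(torch.randn(2, 3, 224, 224))  # -> shape (2, 3)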
class LightningModel(L.LightningModule):
    def __init__(self, config, learning_rate=0.00001):
        super().__init__()
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        self.config = config
        self.criterion = Config.CRITERION
        self.optimizer = Config.OPTIMIZER
        self.scheduler = Config.SCHEDULER
        self.model = ObstructionDetectionModel(config)
        self.model = torch.compile(self.model)
        self.train_acc = MultilabelAccuracy(num_labels=config["classes"])
        self.val_acc = MultilabelAccuracy(num_labels=config["classes"])
        self.test_acc = MultilabelAccuracy(num_labels=config["classes"])

    def forward(self, x):
        return self.model(x)

    def _shared_step(self, batch):
        sample, target = batch
        logits = self(sample)
        loss = self.criterion(logits, target.float())
        return loss, target, logits

    def on_train_epoch_end(self):
        # Log the wall-clock timestamp as integer date/time metrics; the checkpoint
        # filename in train.py is built from the validation counterparts.
        _train_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        train_ts_date = int(_train_timestamp.split("_")[0])
        train_ts_time = int(_train_timestamp.split("_")[1])
        self.log("train_ts_date", train_ts_date)
        self.log("train_ts_time", train_ts_time)

    def on_validation_epoch_end(self):
        _val_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        val_ts_date = int(_val_timestamp.split("_")[0])
        val_ts_time = int(_val_timestamp.split("_")[1])
        self.log("val_ts_date", val_ts_date)
        self.log("val_ts_time", val_ts_time)

    def training_step(self, batch, batch_idx):
        loss, true_labels, logits = self._shared_step(batch)
        self.log("train_loss", loss, prog_bar=True, on_epoch=True, on_step=False)
        self.train_acc(logits, true_labels)
        self.log(
            "train_acc", self.train_acc, prog_bar=True, on_epoch=True, on_step=False
        )
        return loss

    def validation_step(self, batch, batch_idx):
        loss, true_labels, logits = self._shared_step(batch)
        self.log("val_loss", loss, prog_bar=True)
        self.val_acc(logits, true_labels)
        self.log("val_acc", self.val_acc, prog_bar=True)

    def test_step(self, batch, batch_idx):
        _, true_labels, logits = self._shared_step(batch)
        self.test_acc(logits, true_labels)
        self.log("test_acc", self.test_acc)

    def configure_optimizers(self):
        optimizer = self.optimizer(self.parameters(), lr=self.learning_rate)
        scheduler = self.scheduler(optimizer, T_max=5)
        return [optimizer], [scheduler]


if __name__ == "__main__":
    pass
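Because training uses `BCEWithLogitsLoss` together with `MultilabelAccuracy`, the head emits independent per-label logits: at inference you apply a sigmoid and threshold each label rather than taking a softmax/argmax. A minimal sketch (the 0.5 threshold is an assumption):

import torch

logits = torch.tensor([[1.2, -0.4, 2.7]])  # raw outputs for one image with 3 labels
probs = torch.sigmoid(logits)              # independent per-label probabilities
preds = (probs > 0.5).int()                # multi-hot prediction: tensor([[1, 0, 1]])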
# train.py
import os
import time
import torch
import mlflow
from watermark import watermark
from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.callbacks import RichProgressBar
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks import LearningRateFinder
from lightning.pytorch.callbacks import LearningRateMonitor
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

from src.data import TFFCDataModule
from src.models import LightningModel

os.environ["LOGNAME"] = "krishnatheja.vanka"
torch.set_float32_matmul_precision("high")


class Config:
    MODEL_NAME = "class3"


def cli_main():
    # Version runs by UTC timestamp, e.g. artifacts/iter_20250518_201200_class3/
    version_no = time.strftime("%Y%m%d_%H%M%S", time.gmtime())
    model_name = Config.MODEL_NAME
    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        dirpath=f"artifacts/iter_{version_no}_{model_name}/",
        filename="RC_{val_ts_date:.12g}_{val_ts_time:g}_ep{epoch:03d}",
        mode="min",
        save_top_k=1,
        auto_insert_metric_name=False,
        save_on_train_epoch_end=False,
    )
    mlflow.set_experiment(experiment_name="tffc-road-condition-multi-label")
    mlflow.start_run(run_name=f"{version_no}")
    mlflow.pytorch.autolog()
    cli = LightningCLI(
        LightningModel,
        TFFCDataModule,
        seed_everything_default=42,
        run=False,
        trainer_defaults={
            "fast_dev_run": False,
            "callbacks": [
                checkpoint_callback,
                RichProgressBar(),
                LearningRateFinder(),
                LearningRateMonitor(logging_interval="epoch"),
                EarlyStopping(monitor="train_acc", mode="max"),
            ],
        },
    )
    cli.trainer.fit(cli.model, datamodule=cli.datamodule)
    cli.trainer.test(ckpt_path="best", datamodule=cli.datamodule)
    mlflow.end_run()


if __name__ == "__main__":
    # python train.py --config conf/masterconfig.yaml
    print(watermark(packages="torch,lightning", python=True))
    print("Torch CUDA available?", torch.cuda.is_available())
    cli_main()
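The script is normally driven through LightningCLI with a YAML config (`python train.py --config conf/masterconfig.yaml`), but the same pieces can be wired together manually for a quick smoke run. A minimal sketch, assuming the config keys read by the data module and model above; paths, column names, and the class count are placeholders:

import lightning as L

from src.data import TFFCDataModule
from src.models import LightningModel

config = {
    "batch_size": 32,
    "classes": 3,                # assumed number of labels
    "root_path": "data/images",
    "train_df": "data/train.csv",
    "val_df": "data/val.csv",
    "test_df": "data/test.csv",
    "file_column": "filepath",
    "label_column": "labels",
}
model = LightningModel(config, learning_rate=1e-4)
datamodule = TFFCDataModule(config)
trainer = L.Trainer(max_epochs=5, fast_dev_run=True)  # fast_dev_run: one-batch sanity pass
trainer.fit(model, datamodule=datamodule)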