@czhu12
Last active March 25, 2021 03:13
Base image classification model
import os

import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from PIL import Image
from pytorch_lightning.metrics.functional import accuracy
from torch import optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms


class BaseImageClassificationModel(pl.LightningModule):
    """Base image classification model.

    Subclasses are expected to provide ``self.net``, ``self.learning_rate``,
    and the ``train_dataset`` / ``val_dataset`` / ``test_dataset`` attributes,
    and to define a ``training_step``.
    """

    def forward(self, x):
        # In Lightning, forward defines the prediction/inference actions.
        predictions = self.net(x)
        return predictions

    def validation_step(self, batch, batch_idx):
        x, y = batch
        output = self.net(x)
        loss = F.cross_entropy(output, y)
        preds = torch.argmax(output, dim=1)
        acc = accuracy(preds, y)
        # Calling self.log surfaces these scalars in TensorBoard.
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        # Reuse validation_step for testing.
        return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        optimizer = optim.SGD(self.net.parameters(), lr=self.learning_rate,
                              momentum=0.9, weight_decay=5e-4)
        # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
        return [optimizer]  # , [scheduler]

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=16, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=16)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=16)
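

# A minimal sketch (not from the original gist) of how this base class is
# meant to be subclassed: the subclass fills in the attributes the base class
# assumes (self.net, self.learning_rate, and the three datasets) and defines a
# training_step. The ResNet-18 backbone, the two-class head, and the default
# learning rate are assumptions chosen for illustration.
from torchvision import models


class ChestXRayClassifier(BaseImageClassificationModel):
    def __init__(self, train_dataset, val_dataset, test_dataset, learning_rate=0.01):
        super().__init__()
        self.learning_rate = learning_rate
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset
        # Assumed backbone: ResNet-18 with a fresh two-class output layer.
        self.net = models.resnet18(pretrained=False)
        self.net.fc = torch.nn.Linear(self.net.fc.in_features, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self.net(x), y)
        self.log('train_loss', loss)
        return loss
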
class ChestXRayDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        normal_directory = os.path.join(root_dir, "NORMAL")
        pneumonia_directory = os.path.join(root_dir, "PNEUMONIA")
        normal_images = [os.path.join(normal_directory, path) for path in os.listdir(normal_directory)]
        pneumonia_images = [os.path.join(pneumonia_directory, path) for path in os.listdir(pneumonia_directory)]
        # One label per image in each class.
        negative_labels = [0] * len(normal_images)
        positive_labels = [1] * len(pneumonia_images)
        self.df = pd.DataFrame(zip(normal_images + pneumonia_images, negative_labels + positive_labels))
        self.df.columns = ['image_path', 'label']
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image_path = row['image_path']
        img = Image.open(image_path).convert('RGB')
        img = img.resize((224, 224))
        if self.transform:
            img = self.transform(img)
        # No batch dimension is added here; the DataLoader handles batching.
        return img, row['label']

    @staticmethod
    def transform_train():
        return transforms.Compose([
            transforms.RandomCrop(224, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    @staticmethod
    def transform_test():
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
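

# A hedged end-to-end usage sketch (not from the original gist). It assumes the
# ChestXRayClassifier sketched above and a local copy of the chest X-ray data
# laid out as <root>/{train,val,test}/{NORMAL,PNEUMONIA}; the "data/chest_xray"
# path is a placeholder.
if __name__ == '__main__':
    train_dataset = ChestXRayDataset('data/chest_xray/train',
                                     transform=ChestXRayDataset.transform_train())
    val_dataset = ChestXRayDataset('data/chest_xray/val',
                                   transform=ChestXRayDataset.transform_test())
    test_dataset = ChestXRayDataset('data/chest_xray/test',
                                    transform=ChestXRayDataset.transform_test())
    model = ChestXRayClassifier(train_dataset, val_dataset, test_dataset)
    trainer = pl.Trainer(max_epochs=1)
    trainer.fit(model)
    trainer.test(model)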