Last active
March 25, 2021 03:13
-
-
Save czhu12/f76dc3ecbc9a8e71acb291bb80c5278a to your computer and use it in GitHub Desktop.
Base image classification model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch.optim.lr_scheduler as lr_scheduler | |
import pytorch_lightning as pl | |
import torch.nn.functional as F | |
import torch | |
from torch import optim | |
import torch.optim.lr_scheduler as lr_scheduler | |
from torchvision.datasets import CIFAR10, CIFAR100 | |
from torch.utils.data import DataLoader, random_split | |
from pytorch_lightning.metrics.functional import accuracy | |
from torch.utils.data import Dataset, DataLoader | |
from PIL import Image | |
from torchvision import transforms | |
import os | |
import pandas as pd | |
class BaseImageClassificationModel(pl.LightningModule):
    """Shared training/evaluation logic for image classifiers.

    Subclasses are expected to set the following before training starts:

    * ``self.net`` -- the module whose output is treated as class logits,
    * ``self.learning_rate`` -- learning rate passed to SGD,
    * ``self.train_dataset``, ``self.val_dataset``, ``self.test_dataset`` --
      datasets yielding ``(input, target)`` pairs.

    ``batch_size`` (default 16) may be overridden on a subclass or instance.
    """

    # Mini-batch size shared by all three dataloaders.
    batch_size = 16

    def forward(self, x):
        """Return the output of ``self.net`` for the input batch ``x``."""
        # In Lightning, forward defines the prediction/inference action.
        return self.net(x)

    def _evaluate(self, batch, stage):
        """Compute and log cross-entropy loss and accuracy for one batch.

        Metrics are logged as ``<stage>_loss`` and ``<stage>_acc`` so that
        validation and test results are reported under distinct names.
        Returns the loss tensor.
        """
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)
        # self.log surfaces these scalars in TensorBoard and the progress bar.
        self.log(f'{stage}_loss', loss, prog_bar=True)
        self.log(f'{stage}_acc', acc, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch; logs ``val_loss`` / ``val_acc``."""
        return self._evaluate(batch, 'val')

    def test_step(self, batch, batch_idx):
        """Evaluate one test batch; logs ``test_loss`` / ``test_acc``."""
        return self._evaluate(batch, 'test')

    def configure_optimizers(self):
        """Build SGD (momentum 0.9, weight decay 5e-4) over ``self.net``'s parameters."""
        optimizer = optim.SGD(
            self.net.parameters(),
            lr=self.learning_rate,
            momentum=0.9,
            weight_decay=5e-4,
        )
        # To add an LR schedule, return ([optimizer], [scheduler]), e.g. with
        # lr_scheduler.CosineAnnealingLR(optimizer, T_max=200).
        return [optimizer]

    def train_dataloader(self):
        """Shuffled loader over ``self.train_dataset``."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Loader over ``self.val_dataset`` (validation must not read the test split)."""
        return DataLoader(self.val_dataset, batch_size=self.batch_size)

    def test_dataloader(self):
        """Loader over ``self.test_dataset``."""
        return DataLoader(self.test_dataset, batch_size=self.batch_size)
class ChestXRayDataset(Dataset):
    """Binary chest X-ray classification dataset.

    Expects ``root_dir`` to contain two sub-directories, ``NORMAL`` and
    ``PNEUMONIA``.

    * Every entry directly inside ``NORMAL`` is labelled ``0``.
    * Every entry inside ``PNEUMONIA`` is labelled ``1``.

    Items are ``(image, label)`` pairs. ``image`` is the RGB image resized to
    224x224, then passed through ``transform`` when one is given.
    """

    # (sub-directory, label) pairs; this order fixes the row order of self.df.
    _CLASS_DIRS = (("NORMAL", 0), ("PNEUMONIA", 1))
    # ImageNet channel statistics shared by both transform pipelines.
    _MEAN = (0.485, 0.456, 0.406)
    _STD = (0.229, 0.224, 0.225)

    def __init__(self, root_dir, transform=None):
        """Index every file under the class sub-directories of ``root_dir``.

        Args:
            root_dir: Directory containing ``NORMAL`` and ``PNEUMONIA``.
            transform: Optional callable applied to each PIL image in
                ``__getitem__`` (e.g. the result of ``transform_train()``).
        """
        image_paths = []
        labels = []
        for class_dir, label in self._CLASS_DIRS:
            directory = os.path.join(root_dir, class_dir)
            # os.listdir order is filesystem-dependent; sort for reproducibility.
            paths = [os.path.join(directory, name) for name in sorted(os.listdir(directory))]
            image_paths.extend(paths)
            # One label per image of *this* class, so the paths and labels
            # always line up no matter how unbalanced the classes are.
            labels.extend([label] * len(paths))
        # Building from a dict keeps the column names even when no files exist.
        self.df = pd.DataFrame({'image_path': image_paths, 'label': labels})
        self.transform = transform

    def __len__(self):
        """Number of indexed images."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load image ``idx`` and return ``(image, label)``."""
        row = self.df.iloc[idx]
        # The context manager closes the file handle. convert() returns a new
        # in-memory image, so it stays valid after the file is closed.
        with Image.open(row['image_path']) as raw:
            img = raw.convert('RGB')
        img = img.resize((224, 224))
        if self.transform is not None:
            img = self.transform(img)
        # No unsqueeze here: DataLoader adds the batch dimension itself.
        return img, row['label']

    @staticmethod
    def transform_train():
        """Training augmentation.

        Steps: random 224 crop (4 px padding), random horizontal flip,
        tensor conversion, then ImageNet normalization.
        """
        return transforms.Compose([
            transforms.RandomCrop(224, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=ChestXRayDataset._MEAN,
                                 std=ChestXRayDataset._STD),
        ])

    @staticmethod
    def transform_test():
        """Deterministic evaluation transform: tensor conversion + ImageNet normalization."""
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=ChestXRayDataset._MEAN,
                                 std=ChestXRayDataset._STD),
        ])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment