# PL segmentation gist
# GitHub gist by @talhaanwarch (id 13ffc9f14043ab7933899f41a8996bb5),
# last active September 7, 2021 19:07.
# Data download link:
# https://drive.google.com/file/d/1EwjJx-V-Gq7NZtfiT6LZPLGXD2HN--qT/view?usp=sharing
from glob import glob
import cv2
from PIL import Image
import os
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from matplotlib import pyplot as plt
# Root folder of the EYTH dataset (must contain images/ and masks/ sub-folders).
# Set the EYTH_DATA_DIR environment variable to point elsewhere instead of
# editing this machine-specific default.
path = os.environ.get('EYTH_DATA_DIR', 'D:/image/classification/2D/PSL/data/segmentation/eyth_dataset/')
# In[3]:
# # Data issue: the files in two of the mask folders (vid4 and vid9) have no
# # file extension. Run this once to append '.png' to them:
# for i in glob(path+'masks/vid4/**'):
#     os.rename(i, i+'.png')
# for i in glob(path+'masks/vid9/**'):
#     os.rename(i, i+'.png')
# In[4]:
def get_path(root=None):
    """Collect sorted image and mask file paths of the dataset.

    Images are the ``*.jpg`` files in every ``<root>/images/<subfolder>/`` and
    masks are the ``*.png`` files in every ``<root>/masks/<subfolder>/``.

    Args:
        root: dataset root folder; defaults to the module-level ``path``.

    Returns:
        ``(images, masks)``: two lexicographically sorted lists of file paths.

    Raises:
        ValueError: if the number of images and masks differs. Callers pair the
            two lists by position, so a mismatch would silently misalign every
            following image/mask pair.
    """
    if root is None:
        root = path
    images = sorted(
        file
        for folder in glob(root + '/images/*/')
        for file in glob(folder + '*.jpg')
    )
    masks = sorted(
        file
        for folder in glob(root + '/masks/*/')
        for file in glob(folder + '*.png')
    )
    if len(images) != len(masks):
        raise ValueError(
            f'found {len(images)} images but {len(masks)} masks under {root!r}; '
            'image/mask pairs would be misaligned'
        )
    return images, masks
images, masks = get_path()
# In[5]:
# Pair every image with its mask by position (both lists are sorted by path).
data_dicts = [
    {"image": image_name, "label": label_name}
    for image_name, label_name in zip(images, masks)
]
print(len(data_dicts))
train_files, val_files = train_test_split(data_dicts, test_size=0.2, random_state=21)
# A bare expression is only displayed in a notebook cell; print it so the
# split sizes are also visible when this file runs as a script.
print(len(train_files), len(val_files))
# In[6]:
import cv2
# Sanity check: show the first training image next to its mask.
image = cv2.imread(train_files[0]['image'])
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
mask = cv2.imread(train_files[0]['label'], cv2.IMREAD_GRAYSCALE)
print(image.shape)
print(mask.max(), mask.min())
fig, ax = plt.subplots(1, 2)
ax[0].imshow(image)
ax[1].imshow(mask, cmap='gray')
# In[7]:
from pytorch_lightning import seed_everything, LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping,ModelCheckpoint,LearningRateMonitor
from torch.utils.data import DataLoader,Dataset
from pytorch_lightning.loggers import TensorBoardLogger
from torch.optim.lr_scheduler import ReduceLROnPlateau,CosineAnnealingWarmRestarts
import torch.nn as nn
import torch
import torchvision
from torch.nn import functional as F
# In[8]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
# Augmentation pipelines: both resize to 224x224, apply Normalize(mean=0, std=1)
# and convert to tensors; only the training pipeline adds a random horizontal flip.
# NOTE(review): mean=0 / std=1 should reduce Normalize to a division by its
# max_pixel_value (255 by default) — confirm against the albumentations docs. The
# encoder is built with encoder_weights='imagenet', whose usual input
# normalization is the ImageNet mean/std rather than this.
train_aug = A.Compose(
    [
        A.Resize(224, 224),
        A.HorizontalFlip(p=0.5),
        A.Normalize(mean=0, std=1),
        ToTensorV2(p=1.0),
    ],
    p=1.0,
)
val_aug = A.Compose(
    [
        A.Resize(224, 224),
        A.Normalize(mean=0, std=1),
        ToTensorV2(p=1.0),
    ],
    p=1.0,
)
# In[9]:
class DataReader(Dataset):
    """Dataset of (image, mask) pairs loaded from disk with OpenCV.

    Args:
        data: sequence of dicts, each with an ``'image'`` and a ``'label'``
            (mask) file path.
        transform: optional albumentations-style callable invoked as
            ``transform(image=..., mask=...)`` and returning a dict with
            ``'image'`` and ``'mask'`` entries.

    Each item is ``(image, mask)``: the RGB image (transformed if a transform
    is given) and the grayscale mask with a leading channel axis added and
    divided by 255.
    """

    def __init__(self, data, transform=None):
        super().__init__()
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        record = self.data[index]
        image_path = record['image']
        mask_path = record['label']
        image = cv2.imread(image_path)
        mask = cv2.imread(mask_path)
        # cv2.imread signals an unreadable/missing file by returning None instead
        # of raising; fail here with the offending path rather than inside cvtColor.
        if image is None:
            raise FileNotFoundError(f'could not read image: {image_path}')
        if mask is None:
            raise FileNotFoundError(f'could not read mask: {mask_path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
        if self.transform:
            transformed = self.transform(image=image, mask=mask)
            image = transformed['image']
            mask = transformed['mask']
        # add the channel axis and scale the 0..255 mask down by 255
        mask = np.expand_dims(mask, 0) / 255
        return image, mask
# In[10]:
# Load one augmented batch from the full dataset and visualize it.
ds = DataReader(data=data_dicts, transform=train_aug)
loader = DataLoader(ds, batch_size=8, shuffle=True, num_workers=0)
batch = next(iter(loader))
print(batch[0].shape, batch[1].shape)
print(batch[1].max(), batch[1].min())
# In[11]:
plt.figure()
grid_img = torchvision.utils.make_grid(batch[0], nrow=4, padding=4)
plt.imshow(grid_img.permute(1, 2, 0))
plt.title('batch of images')
plt.figure()
# DataReader already divides masks by 255, so they lie in [0, 1]; multiplying
# by 255 again would push them outside imshow's valid float range (clipped).
grid_img = torchvision.utils.make_grid(batch[1], nrow=4, padding=4)
plt.imshow(grid_img.permute(1, 2, 0))
plt.title('batch of masks')
# In[12]:
from einops import rearrange
def dice_coef(mask_pred, mask_gt):
    """Class-averaged soft Dice coefficient, used as a monitoring metric.

    Args:
        mask_pred: raw model outputs of shape (B, C, H, W); softmax is applied
            over the channel axis inside this function.
        mask_gt: integer-valued labels of shape (B, 1, H, W) with values in
            ``[0, C)``; they are one-hot encoded to C channels.

    Returns:
        Scalar tensor: the mean over the C classes of the Dice score. Each
        class's score is computed over the whole batch at once (not per
        sample). The result does not require grad.
    """

    def compute_dice_coefficient(mask_pred, mask_gt, smooth=0.0001):
        """Soft Dice between two same-shaped tensors.

        ``smooth`` keeps the ratio defined: when both inputs sum to zero the
        result is 1, not NaN.
        """
        volume_sum = mask_gt.sum() + mask_pred.sum()
        volume_intersect = (mask_gt * mask_pred).sum()
        return (2 * volume_intersect + smooth) / (volume_sum + smooth)

    # A metric never needs gradients; avoid building an autograd graph.
    with torch.no_grad():
        n_pred_ch = mask_pred.shape[1]
        probs = torch.softmax(mask_pred, 1)
        gt = F.one_hot(mask_gt.long(), num_classes=n_pred_ch)  # (B, 1, H, W, C)
        # (B, 1, H, W, C) -> (B, C, H, W); same as einops 'd0 d1 d2 d3 d4 -> d0 (d1 d4) d2 d3'
        gt = gt.permute(0, 1, 4, 2, 3).flatten(1, 2)
        dice = sum(
            compute_dice_coefficient(probs[:, ind], gt[:, ind])
            for ind in range(n_pred_ch)
        )
    return dice / n_pred_ch  # average over classes
# In[13]:
from einops import rearrange
def dice_loss(pred, true, softmax=True, sigmoid=False, one_hot=True, background=True, smooth=0.0001):
    """Soft Dice loss averaged over batch and channels.

    Args:
        pred: predicted values without any activation applied, shape
            (B, C, H, W), e.g. (4, 59, 512, 512).
        true: ground truth of shape (B, 1, H, W), e.g. (4, 1, 512, 512).
        softmax: apply softmax over the channel axis (multi-class).
        sigmoid: apply an element-wise sigmoid (binary).
        one_hot: one-hot encode ``true`` into C channels to match ``pred``.
        background: if False, drop channel 0 from both tensors before scoring.
        smooth: additive smoothing; a channel that is empty in both tensors
            scores 1 instead of 0/0.

    Returns:
        Scalar tensor ``1 - mean(dice)``, where dice is computed separately for
        every (sample, channel) pair over the spatial axes.

    Raises:
        ValueError: if softmax or one-hot encoding is requested for a
            single-channel ``pred``, or if ``background=False`` is used
            without ``one_hot=True``.
    """
    n_pred_ch = pred.shape[1]
    if softmax:
        if n_pred_ch == 1:
            raise ValueError('single channel found')
        pred = torch.softmax(pred, 1)
    if sigmoid:
        pred = torch.sigmoid(pred)  # element-wise; takes no dim argument
    if one_hot:
        if n_pred_ch == 1:
            raise ValueError('single channel found')
        true = F.one_hot(true.long(), num_classes=n_pred_ch)  # (B, 1, H, W, C)
        # (B, 1, H, W, C) -> (B, C, H, W); same as einops 'd0 d1 d2 d3 d4 -> d0 (d1 d4) d2 d3'
        true = true.permute(0, 1, 4, 2, 3).flatten(1, 2)
    if not background:
        # Dropping channel 0 only makes sense when `true` has one channel per class.
        if not one_hot:
            raise ValueError('background=False requires one_hot=True (apply one hot encode)')
        true = true[:, 1:]
        pred = pred[:, 1:]
    # reduce only the spatial dimensions (not batch nor channels)
    reduce_axis = list(range(2, true.ndim))
    intersection = torch.sum(true * pred, dim=reduce_axis)
    denominator = torch.sum(true, dim=reduce_axis) + torch.sum(pred, dim=reduce_axis)
    dice = (2.0 * intersection + smooth) / (denominator + smooth)
    return 1.0 - torch.mean(dice)  # average over batch and channels
# In[14]:
from einops import rearrange
def focal_dice_loss(pred, true, softmax=True, alpha=0.5, gamma=2):
    """Equal-weight combination of a focal cross-entropy term and the Dice loss.

    Args:
        pred: predicted values without any activation applied (logits), shape
            (B, C, H, W), e.g. (4, 59, 512, 512).
        true: ground truth of shape (B, 1, H, W), e.g. (4, 1, 512, 512), with
            integer-valued class labels.
        softmax: if True, softmax over channels is applied to ``pred`` before
            the Dice term. The cross-entropy term always takes ``pred`` as
            logits.
        alpha: weighting factor of the focal term.
        gamma: focusing exponent; larger values down-weight pixels that are
            already classified confidently.

    Returns:
        Scalar tensor ``0.5 * focal + 0.5 * dice_loss``.

    Raises:
        ValueError: if ``softmax`` is requested for a single-channel ``pred``.
    """
    n_pred_ch = pred.shape[1]
    if softmax and n_pred_ch == 1:
        raise ValueError('single channel found')
    target = torch.squeeze(true, dim=1).long()
    # cross_entropy applies log_softmax internally, so it must receive the raw
    # logits; passing softmax probabilities here would apply softmax twice.
    celoss = F.cross_entropy(pred, target, reduction='none')
    pt = torch.exp(-celoss)  # probability assigned to the true class per pixel
    # Focal loss: scale the cross-entropy by (1 - pt)^gamma.
    focal_loss = torch.mean(alpha * (1 - pt) ** gamma * celoss)
    probs = torch.softmax(pred, 1) if softmax else pred
    diceloss = dice_loss(probs, true, softmax=False)  # activation already applied
    return 0.5 * focal_loss + 0.5 * diceloss
# In[15]:
import segmentation_models_pytorch as smp
import torchmetrics
# iou(preds_array,labels_array.type(torch.int).to('cuda'))
class OurModel(LightningModule):
    """Two-class U-Net (ResNet-18 encoder) trained with ``focal_dice_loss``.

    The module builds its own data loaders from the module-level
    ``train_files`` / ``val_files`` splits. It prints the mean loss, IoU and
    Dice at the end of every training and validation epoch, and logs
    ``val_loss`` from each validation step.
    """

    def __init__(self):
        super().__init__()
        # architecture
        self.layer = smp.Unet(
            encoder_name='resnet18',
            encoder_weights='imagenet',
            in_channels=3,
            classes=2,
        )
        # parameters
        self.lr = 1e-3
        self.batch_size = 32
        self.numworker = 0
        self.iou = torchmetrics.IoU(2)

    def forward(self, x):
        return self.layer(x)

    def configure_optimizers(self):
        opt = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=1e-5)
        scheduler = CosineAnnealingWarmRestarts(opt, T_0=10, T_mult=1, eta_min=1e-5, last_epoch=-1)
        return {'optimizer': opt, 'lr_scheduler': scheduler}

    def train_dataloader(self):
        ds = DataReader(data=train_files, transform=train_aug)
        return DataLoader(ds, batch_size=self.batch_size, shuffle=True, num_workers=self.numworker)

    def val_dataloader(self):
        ds = DataReader(data=val_files, transform=val_aug)
        return DataLoader(ds, batch_size=self.batch_size, shuffle=False, num_workers=self.numworker)

    def _shared_step(self, batch):
        """Run the model on one batch; return its loss, IoU and Dice."""
        image, segment = batch[0], batch[1]
        out = self(image)
        loss = focal_dice_loss(out, segment)
        dice = dice_coef(out, segment)
        # The IoU metric is given class probabilities (B, C, H, W) and integer
        # labels (B, H, W): the (B, 1, H, W) mask has its channel axis dropped,
        # and logits are turned into probabilities first.
        iouscore = self.iou(torch.softmax(out, dim=1), segment.squeeze(1).long())
        return {'loss': loss, 'iou': iouscore, 'dice': dice}

    @staticmethod
    def _epoch_means(outputs):
        """Average loss, iou and dice over an epoch's step outputs (rounded to 2 dp)."""
        return [
            torch.stack([x[key] for x in outputs]).mean().detach().cpu().numpy().round(2)
            for key in ('loss', 'iou', 'dice')
        ]

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch)

    def training_epoch_end(self, outputs):
        loss, iou, dice = self._epoch_means(outputs)
        print('training loss, iou, dice ', loss, iou, dice)

    def validation_step(self, batch, batch_idx):
        result = self._shared_step(batch)
        # The script's ModelCheckpoint monitors 'val_loss'; make it available.
        self.log('val_loss', result['loss'])
        return result

    def validation_epoch_end(self, outputs):
        loss, iou, dice = self._epoch_means(outputs)
        print('validation loss, iou, dice ', loss, iou, dice)
# In[16]:
model = OurModel()
# In[17]:
lr_monitor = LearningRateMonitor(logging_interval='epoch')
# NOTE(review): neither lr_monitor nor checkpoint_callback is handed to the
# Trainer below (there is no callbacks=[...] argument), so both are currently
# unused. ModelCheckpoint(monitor='val_loss') presumably also needs the model
# to log a 'val_loss' value via self.log — confirm before enabling it.
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath='unet',
    filename='checkpoint',
)
trainer_options = dict(
    max_epochs=5,
    gpus=-1,
    precision=16,
    stochastic_weight_avg=True,
)
trainer = Trainer(**trainer_options)
# In[18]:
trainer.fit(model)
# In[ ]:
trainer.validate(model)
# In[ ]:
# End of gist.