Skip to content

Instantly share code, notes, and snippets.

@stefanherdy
Last active October 4, 2023 09:57
Show Gist options
  • Save stefanherdy/2323eff239c8626e39b232506c994368 to your computer and use it in GitHub Desktop.
Save stefanherdy/2323eff239c8626e39b232506c994368 to your computer and use it in GitHub Desktop.
PyTorch data augmentation script for semantic image segmentation. For further details, please have a look at my story on Medium: https://medium.com/@stefan.herdy/how-to-augment-images-for-semantic-segmentation-2d7df97544de . A full semantic segmentation project can be found here: https://github.com/stefanherdy/pytorch-semantic-segmentation
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
from skimage.io import imread
from torch.utils import data
from torchvision import transforms
class RandomFlip:
    """Randomly mirror an image and its segmentation mask together.

    The image is treated as channel-first ``(C, H, W)`` and the mask as
    ``(H, W)``: the image is flipped along its last two axes and the mask
    along its first two, so both stay pixel-aligned.

    Args:
        p: Probability of applying each of the two flips. The left-right
            and the up-down flip are decided independently.
    """

    def __init__(self, p: float = 0.5):
        self.p = p

    def __call__(self, inp: np.ndarray, tar: np.ndarray):
        """Return ``(inp, tar)`` with the same random flips applied to both.

        When no flip is drawn, the input objects are returned unchanged;
        flipped results are fresh copies.
        """
        # Left-right mirror (reverse the width axis).
        if random.random() < self.p:
            inp = np.flip(inp, axis=-1).copy()
            tar = np.fliplr(tar).copy()
        # Up-down mirror (reverse the height axis).
        if random.random() < self.p:
            inp = np.flip(inp, axis=-2).copy()
            tar = np.flipud(tar).copy()
        return inp, tar
class RandomCropTrain:
    """Cut the same random window out of an image and its mask.

    The image is treated as channel-first ``(C, H, W)`` (the crop limits are
    derived from ``inp.shape[1]`` and ``inp.shape[2]``) and the mask as
    ``(H, W)``.

    Args:
        crop_width: Width of the crop window in pixels.
        crop_height: Height of the crop window in pixels.
    """

    def __init__(self, crop_width: int = 1900, crop_height: int = 1900):
        self.crop_width = crop_width
        self.crop_height = crop_height

    def __call__(self, inp: np.ndarray, tar: np.ndarray):
        """Return the cropped ``(inp, tar)`` pair.

        Raises:
            ValueError: If the image is smaller than the crop window.
        """
        crop_h, crop_w = self.crop_height, self.crop_width
        max_top = inp.shape[1] - crop_h
        max_left = inp.shape[2] - crop_w
        if max_top < 0 or max_left < 0:
            raise ValueError(
                f"Crop size ({crop_h}, {crop_w}) exceeds image size "
                f"({inp.shape[1]}, {inp.shape[2]})"
            )
        # randint's upper bound is exclusive; +1 keeps the last valid offset
        # reachable and makes an image of exactly the crop size work.
        top = np.random.randint(0, max_top + 1)
        left = np.random.randint(0, max_left + 1)
        # Keep every channel; crop only the spatial axes.
        inp = inp[:, top: top + crop_h, left: left + crop_w]
        tar = tar[top: top + crop_h, left: left + crop_w]
        return inp, tar
class DataSet(data.Dataset):
    """Dataset that reads an input/target pair with ``imread`` on access.

    ``inputs`` and ``targets`` are indexed with the same position; each entry
    is handed to ``imread``. An optional ``transform`` callable receives and
    returns the ``(input, target)`` pair before tensor conversion.
    """

    def __init__(self,
                 inputs: list,
                 targets: list,
                 transform=None,
                 ):
        self.transform = transform
        self.targets = targets
        self.inputs = inputs
        # Output dtypes of the tensors produced by __getitem__.
        self.targets_dtype = torch.int
        self.inputs_dtype = torch.float32

    def __len__(self):
        """Number of samples, taken from the inputs list."""
        return len(self.inputs)

    def __getitem__(self,
                    idx: int):
        """Load, optionally transform, and tensorize the sample at ``idx``."""
        image_path = self.inputs[idx]
        mask_path = self.targets[idx]

        image = imread(image_path)
        mask = imread(mask_path)

        # NOTE(review): arrays reach the transform exactly as imread returns
        # them; confirm this matches the layout the transforms expect.
        if self.transform is not None:
            image, mask = self.transform(image, mask)

        # Copy before from_numpy so non-contiguous / negatively-strided
        # arrays coming out of the transforms are accepted.
        image_tensor = torch.from_numpy(image.copy()).type(self.inputs_dtype)
        mask_tensor = torch.from_numpy(mask.copy()).type(self.targets_dtype)
        return image_tensor, mask_tensor
class Compose:
    """
    Chain paired transforms: each one receives the (input, target) pair
    produced by the previous one.
    """

    def __init__(self, transforms: list):
        self.transforms = transforms

    def __call__(self, input, target):
        """Run every transform in order and return the final pair."""
        current_input, current_target = input, target
        for step in self.transforms:
            current_input, current_target = step(current_input, current_target)
        return current_input, current_target

    def __repr__(self):
        """String form of the list of contained transforms."""
        return str(list(self.transforms))
from torch.utils.data import DataLoader

# Example usage. `inputs`, `targets` and `batchsize` are not defined in this
# script and must be supplied by the caller before this point: `inputs` and
# `targets` are the lists handed to DataSet, `batchsize` is the DataLoader
# batch size.

# Augmentation pipeline: random flips followed by a random crop.
transforms_train = Compose([
    RandomFlip(),
    RandomCropTrain(),
])

# train dataset
dataset_train = DataSet(inputs=inputs,
                        targets=targets,
                        transform=transforms_train)

# train dataloader
dataloader_training = DataLoader(dataset=dataset_train,
                                 batch_size=batchsize,
                                 shuffle=True
                                 )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment