Skip to content

Instantly share code, notes, and snippets.

@Z-Zheng
Created May 13, 2019 03:23
Show Gist options
  • Save Z-Zheng/3ad65a26f5bac95b83f0f261d89af6db to your computer and use it in GitHub Desktop.
Save Z-Zheng/3ad65a26f5bac95b83f0f261d89af6db to your computer and use it in GitHub Desktop.
dataset class and transform class for lxy
from torch.utils.data import Dataset
import glob
import os
from skimage.io import imread
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from simplecv.util import tensor_util
from simplecv.interface import CVModule
from simplecv.data import preprocess
import torch
import torch.nn.functional as F
# Colour table for class-index masks, stored flat (R, G, B, R, G, B, ...) so
# that class index ``i`` is drawn with ``pallete[3 * i:3 * i + 3]``.
# It is handed to ``Image.putpalette`` in ``get_color_pallete``.
pallete = [
    channel
    for rgb in (
        (220, 20, 60),
        (128, 64, 128),
        (70, 70, 70),
        (102, 102, 156),
        (190, 153, 153),
        (153, 153, 153),
        (250, 170, 30),
        (220, 220, 0),
        (107, 142, 35),
        (0, 130, 180),
        (255, 0, 0),
        (0, 0, 142),
        (0, 0, 70),
        (0, 60, 100),
        (244, 35, 232),
        (0, 80, 100),
        (119, 11, 32),
        (0, 0, 230),
        (152, 251, 152),
    )
    for channel in rgb
]
def get_color_pallete(npimg):
    """Turn a class-index map into a palettised PIL image.

    Parameters
    ----------
    npimg : numpy.ndarray
        Class-index image with shape `H, W` or `H, W, 1`. Values are cast
        to ``uint8`` before use.

    Returns
    -------
    out_img : PIL.Image
        Image whose pixel values are the class indices and whose colour
        table is the module-level ``pallete``.
    """
    index_map = np.asarray(npimg)
    # ``Image.fromarray`` has no mode for a 3-D array with a single channel,
    # so drop the trailing axis to honour the documented `H, W, 1` input.
    if index_map.ndim == 3 and index_map.shape[-1] == 1:
        index_map = index_map[:, :, 0]
    out_img = Image.fromarray(index_map.astype('uint8'))
    out_img.putpalette(pallete)
    return out_img
def plot_mask(img, masks, alpha=0.5):
    """Blend segmentation masks onto an image.

    Parameters
    ----------
    img : numpy.ndarray
        Image with shape `H, W, 3`.
    masks : numpy.ndarray
        Stack of masks with shape `N, H, W`; every pixel ``> 0`` counts as
        foreground.
    alpha : float, optional, default 0.5
        Weight of the overlay colour (``1 - alpha`` is kept from the image).

    Returns
    -------
    numpy.ndarray
        ``uint8`` copy of the image with every mask painted on top.
    """
    # Fixed seed: the i-th mask receives the same colour on every call.
    rng = np.random.RandomState(567)
    canvas = img
    for single_mask in masks:
        overlay_color = rng.random_sample(3) * 255
        foreground = np.repeat((single_mask > 0)[:, :, np.newaxis], repeats=3, axis=2)
        tinted = canvas * (1 - alpha) + overlay_color * alpha
        canvas = np.where(foreground, tinted, canvas)
    return canvas.astype('uint8')
class DeepglobeRoad(Dataset):
    """Paired satellite-image / road-mask dataset read from one directory.

    Every file matching ``*_sat.jpg`` directly under ``root`` is a sample;
    its mask is expected next to it with the trailing ``sat.jpg`` replaced
    by ``mask.png``.

    Args:
        root: directory that holds the ``*_sat.jpg`` / ``*_mask.png`` pairs.
        training: stored on the instance; not used by this class itself.
        transforms: optional callable ``(image_np, mask_np) -> (image, mask)``.
            When omitted the arrays are only converted with
            ``tensor_util.to_tensor``.
    """

    _IMAGE_SUFFIX = 'sat.jpg'
    _MASK_SUFFIX = 'mask.png'

    def __init__(self, root, training=True, transforms=None):
        self.root = root
        self.transforms = transforms
        self.training = training
        # glob returns files in arbitrary order; sort so that an index maps
        # to the same sample on every run and every machine.
        self.im_path_list = sorted(glob.glob(os.path.join(root, '*_sat.jpg')))
        # Swap only the trailing suffix. ``str.replace`` would also rewrite
        # any 'sat.jpg' that happens to occur earlier in the path.
        suffix_len = len(self._IMAGE_SUFFIX)
        self.mask_path_list = [im_path[:-suffix_len] + self._MASK_SUFFIX
                               for im_path in self.im_path_list]

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            A pair ``(inputs, targets)`` where ``inputs`` is
            ``dict(rgb=<image>, image_filename=<basename of the image>)``
            and ``targets`` is ``dict(cls=<mask>)``.
        """
        image_np = imread(self.im_path_list[idx])
        mask_np = imread(self.mask_path_list[idx])
        if self.transforms is not None:
            image_tensor, mask_tensor = self.transforms(image_np, mask_np)
        else:
            image_tensor, mask_tensor = tensor_util.to_tensor([image_np, mask_np])
        return dict(rgb=image_tensor,
                    image_filename=os.path.basename(self.im_path_list[idx])), dict(cls=mask_tensor)

    def __len__(self):
        """Number of ``*_sat.jpg`` images found under ``root``."""
        return len(self.im_path_list)

    def show_image_mask(self, index, with_mask=True):
        """Draw sample ``index`` with ``plt.imshow``.

        Needs ``self.transforms.config.mean_std_normalize`` (``mean`` and
        ``std``) to undo the normalisation, and expects the image as a
        ``[3, H, W]`` tensor. The caller is responsible for ``plt.show()``.

        Args:
            index: sample index.
            with_mask: overlay the mask on the image when True.
        """
        x, y = self[index]
        # denormalize
        _mean = torch.tensor(self.transforms.config.mean_std_normalize.mean).reshape(3, 1, 1)
        _std = torch.tensor(self.transforms.config.mean_std_normalize.std).reshape(3, 1, 1)
        # to uint8: round instead of truncating (199.99998 must become 200)
        # and clamp, because casting out-of-range floats to uint8 is not
        # well defined.
        rgb = (x['rgb'] * _std + _mean).round().clamp(0, 255).byte()
        # to np, channel-last for matplotlib
        image_np = rgb.permute((1, 2, 0)).numpy()
        if with_mask:
            mask_np = y['cls'].numpy()
            color_mask = np.asarray(get_color_pallete(mask_np))
            # plot_mask wants a stack of masks, so add a leading axis.
            vis_image = plot_mask(image_np, color_mask.reshape([1] + list(color_mask.shape)))
        else:
            vis_image = image_np
        plt.imshow(vis_image)
class DeepglobeRoadTransform(CVModule):
    """Normalise an image / mask pair and, in training mode, augment it.

    Reads ``training``, ``mean_std_normalize`` and ``transforms`` from
    ``self.config``; the defaults are supplied by ``set_defalut_config``.
    """

    def __init__(self, config):
        super(DeepglobeRoadTransform, self).__init__(config)

    def forward(self, images, masks):
        """
        Args:
            images: 3-D array of shape [height, width, channel]
            masks: 2-D array of shape [height, width]
        Returns:
            images_tensor: 3-D float32 tensor of shape [channel, height, width]
            masks_tensor: 2-D int64 tensor of shape [height, width]
        """
        assert images.ndim == 3
        assert masks.ndim == 2

        # Both arrays go through as float32; the mask is cast back to
        # int64 just before returning.
        as_float = [array.astype(np.float32) for array in (images, masks)]
        images_tensor, masks_tensor = tensor_util.to_tensor(as_float)

        norm_cfg = self.config.mean_std_normalize
        images_tensor = preprocess.mean_std_normalize(images_tensor, norm_cfg.mean, norm_cfg.std)

        # Augmentations run on the channel-last layout and only in training.
        augmentations = self.config.transforms if self.config.training else ()
        for augment in augmentations:
            images_tensor, masks_tensor = augment(images_tensor, masks_tensor)

        # [H, W, C] -> [C, H, W]
        return images_tensor.permute((2, 0, 1)), masks_tensor.long()

    def set_defalut_config(self):
        """Install the default configuration.

        NOTE(review): the misspelt method name is kept on purpose; it is
        presumably the hook name the ``CVModule`` base class calls - confirm
        before renaming. The mean / std values look like the usual ImageNet
        statistics on a 0-255 scale - confirm.
        """
        normalization = dict(
            mean=(123.675, 116.28, 103.53),
            std=(58.395, 57.12, 57.375),
        )
        augmentations = [
            THRandomRotate90k(p=0.5),
            THRandomHorizontalFlip(p=0.5),
            THRandomVerticalFlip(p=0.5),
            THRandomScale(),
            THRandomCrop((512, 512)),
        ]
        self.config.update(dict(
            training=True,
            mean_std_normalize=normalization,
            transforms=augmentations,
        ))
################################################## works for lxy #############################################
# example for lxy
class THRandomRotate90k(object):
    """Randomly rotate an image (and its mask) by a multiple of 90 degrees.

    Args:
        p: probability of applying the rotation.
        k: number of quarter turns; when None, one of 1, 2, 3 is drawn
            for every applied call.
    """

    def __init__(self, p=0.5, k=None):
        self.p = p
        self.k = k

    def __call__(self, images, masks=None):
        """ Rotate 90 * k degree for image and mask
        Args:
            images: 3-D tensor of shape [height, width, channel]
            masks: 2-D tensor of shape [height, width]
        Returns:
            ``images_tensor`` alone when ``masks`` is None, otherwise the
            list ``[images_tensor, masks_tensor]``.
        """
        # Honour ``p`` with the same test the flip transforms use;
        # previously ``p`` was stored but the rotation was always applied.
        if self.p < np.random.uniform():
            return images if masks is None else [images, masks]

        k = int(np.random.choice([1, 2, 3], 1)[0]) if self.k is None else self.k
        # Rotate in the spatial plane (dims 0 and 1); channels stay last.
        images_tensor = torch.rot90(images, k, [0, 1])
        if masks is None:
            return images_tensor
        masks_tensor = torch.rot90(masks, k, [0, 1])
        return [images_tensor, masks_tensor]
class THRandomHorizontalFlip(object):
    """Mirror an image (and its mask) along the width axis with probability ``p``."""

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, images, masks=None):
        """
        Args:
            images: 3-D tensor of shape [height, width, channel]
            masks: 2-D tensor of shape [height, width]
        Returns:
            ``images_tensor`` alone when ``masks`` is None, otherwise the
            list ``[images_tensor, masks_tensor]``. Inputs are passed
            through untouched when the flip is not applied.
        """
        # One draw decides for the image and the mask together.
        apply_flip = not (self.p < np.random.uniform())

        outputs = [torch.flip(images, [1]) if apply_flip else images]
        if masks is not None:
            outputs.append(torch.flip(masks, [1]) if apply_flip else masks)
        return outputs[0] if len(outputs) == 1 else outputs
class THRandomVerticalFlip(object):
    """Mirror an image (and its mask) along the height axis with probability ``p``."""

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, images, masks=None):
        """
        Args:
            images: 3-D tensor of shape [height, width, channel]
            masks: 2-D tensor of shape [height, width]
        Returns:
            ``images_tensor`` alone when ``masks`` is None, otherwise the
            list ``[images_tensor, masks_tensor]``. Inputs are passed
            through untouched when the flip is not applied.
        """
        # One draw decides for the image and the mask together.
        apply_flip = not (self.p < np.random.uniform())

        outputs = [torch.flip(images, [0]) if apply_flip else images]
        if masks is not None:
            outputs.append(torch.flip(masks, [0]) if apply_flip else masks)
        return outputs[0] if len(outputs) == 1 else outputs
class THRandomCrop(object):
    """Cut a random window of ``crop_size`` out of an image (and its mask).

    Inputs smaller than the window are zero-padded at the bottom / right
    first, so the output always has exactly ``crop_size`` spatial extent.

    Args:
        crop_size: ``(height, width)`` of the window.
    """

    def __init__(self, crop_size=(512, 512)):
        self.crop_size = crop_size

    def __call__(self, images, masks=None):
        """
        Args:
            images: 3-D tensor of shape [height, width, channel]
            masks: 2-D tensor of shape [height, width]
        Returns:
            A list: ``[images_tensor]`` when ``masks`` is None, otherwise
            ``[images_tensor, masks_tensor]``. Unlike the flip transforms,
            the result is a list even for a single tensor.
        """
        im_h, im_w, _ = images.shape
        c_h, c_w = self.crop_size
        pad_h = max(c_h - im_h, 0)
        pad_w = max(c_w - im_w, 0)
        if pad_h > 0 or pad_w > 0:
            # F.pad lists the pads from the last dim backwards:
            # channel (none), width (right only), height (bottom only).
            images = F.pad(images, [0, 0, 0, pad_w, 0, pad_h], mode='constant', value=0)
            # The mask is optional; padding None used to raise here.
            if masks is not None:
                masks = F.pad(masks, [0, pad_w, 0, pad_h], mode='constant', value=0)
            im_h, im_w, _ = images.shape

        # Draw scalars directly: int() on a 1-element array is deprecated
        # in recent NumPy releases.
        ymin = int(np.random.randint(0, im_h - c_h + 1))
        xmin = int(np.random.randint(0, im_w - c_w + 1))
        ymax = ymin + c_h
        xmax = xmin + c_w

        ret = [images[ymin:ymax, xmin:xmax, :]]
        if masks is not None:
            ret.append(masks[ymin:ymax, xmin:xmax])
        return ret
class THRandomScale(object):
    """Resize an image (and its mask) by a randomly chosen scale factor.

    The candidate factors are evenly spaced over ``scale_range`` (both ends
    included) and a fresh one is drawn on every call.

    Args:
        scale_range: ``(low, high)`` bounds of the scale factor.
        scale_step: nominal spacing used to derive the number of candidates.
    """

    def __init__(self, scale_range=(0.5, 2.0), scale_step=0.25):
        self.scale_factors = np.linspace(scale_range[0], scale_range[1],
                                         int((scale_range[1] - scale_range[0]) / scale_step) + 1)
        # Kept for callers that read ``scale_factor``; it is refreshed with
        # the factor used by the most recent call.
        self.scale_factor = np.random.choice(self.scale_factors, size=1)[0]

    def __call__(self, images, masks=None):
        """
        Args:
            images: 3-D tensor of shape [height, width, channel]
            masks: 2-D tensor of shape [height, width]
        Returns:
            A list: ``[images_tensor]`` when ``masks`` is None, otherwise
            ``[images_tensor, masks_tensor]``.
        """
        # Draw per call. Drawing once in __init__ froze a single factor for
        # the whole lifetime of the transform, i.e. no scale augmentation.
        self.scale_factor = np.random.choice(self.scale_factors, size=1)[0]
        factor = float(self.scale_factor)

        ret = list()
        # F.interpolate wants [N, C, H, W]; convert from and back to [H, W, C].
        _images = images.permute(2, 0, 1)[None, :, :, :]
        images_tensor = F.interpolate(_images, scale_factor=factor, mode='bilinear', align_corners=True)
        images_tensor = images_tensor[0].permute(1, 2, 0)
        ret.append(images_tensor)
        if masks is not None:
            # Nearest neighbour so that no new label values are invented.
            masks_tensor = F.interpolate(masks[None, None, :, :], scale_factor=factor, mode='nearest')[0][0]
            ret.append(masks_tensor)
        return ret
if __name__ == '__main__':
    # Smoke test: load the first sample of a local copy of the data and
    # display it with its mask overlaid.
    data_root = r'D:\deepglobe\road_data\train_0_1'
    default_transform = DeepglobeRoadTransform({})
    dataset = DeepglobeRoad(data_root, transforms=default_transform)
    dataset.show_image_mask(0)
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment