import copy

import cv2
import numpy as np
import pycocotools.mask as mask_util


class CopyPasteAugmentator:
    """Copy-paste cells from another image in the dataset."""

    def __init__(self, d2_dataset,
                 paste_same_class=True,
                 paste_density=(0.3, 0.6),
                 filter_area_thresh=0.1,
                 p=1.0):
        self.data = d2_dataset
        self.n_samples = len(d2_dataset)
        self.paste_same_class = paste_same_class
        if paste_same_class:
            # Group sample indices by the class of their first annotation
            # (the 3 cell types), so cells are only pasted between images
            # of the same class.
            self.cls_indices = [
                [
                    i for i, item in enumerate(d2_dataset)
                    if item['annotations'][0]['category_id'] == cls_index
                ]
                for cls_index in range(3)
            ]
        self.filter_area_thresh = filter_area_thresh
        self.paste_density = paste_density
        self.p = p

    def __call__(self, dataset_dict):
        orig_img = cv2.imread(dataset_dict["file_name"])
        if 'LIVECell_dataset_2021' in dataset_dict["file_name"]:
            return orig_img, dataset_dict
        if np.random.uniform() < self.p:
            # Choose a sample to copy-paste from
            if self.paste_same_class:
                cls_id = dataset_dict['annotations'][0]['category_id']
                random_idx = np.random.randint(0, len(self.cls_indices[cls_id]))
                random_ds_dict = self.data[self.cls_indices[cls_id][random_idx]]
            else:
                random_idx = np.random.randint(0, self.n_samples)
                random_ds_dict = self.data[random_idx]
            # Load the chosen sample
            random_img = cv2.imread(random_ds_dict['file_name'])
            if isinstance(self.paste_density, (list, tuple)):
                paste_density = np.random.uniform(self.paste_density[0], self.paste_density[1])
            else:
                paste_density = self.paste_density
            # Randomly select a fraction of the donor cells to paste over
            selected_cell_ids = np.random.choice(
                len(random_ds_dict['annotations']),
                size=int(round(paste_density * len(random_ds_dict['annotations']))),
                replace=False)
            # Select annotations (we deepcopy only selected ones, not the whole dict)
            selected_annos = [copy.deepcopy(random_ds_dict['annotations'][i])
                              for i in selected_cell_ids]
            if len(selected_annos) == 0:
                # Nothing selected (density rounded down to zero cells)
                return orig_img, dataset_dict
            # Union of all pasted cell masks
            copypaste_mask = mask_util.decode(selected_annos[0]['segmentation']).astype(bool)
            for anno in selected_annos[1:]:
                copypaste_mask |= mask_util.decode(anno['segmentation']).astype(bool)
            # Copy cells over: clip each original annotation by the pasted
            # region, dropping those that become almost fully occluded
            neg_mask = ~copypaste_mask
            filtered_annos = []
            for anno in dataset_dict['annotations']:
                mask = mask_util.decode(anno['segmentation']).astype(bool)
                occluded_mask = (mask & neg_mask)
                if round(self.filter_area_thresh * mask.sum()) < occluded_mask.sum():
                    anno['segmentation'] = mask_util.encode(
                        np.asfortranarray(occluded_mask.astype(np.uint8)))
                    filtered_annos.append(anno)
            # Form output
            orig_img[copypaste_mask] = random_img[copypaste_mask]
            dataset_dict['annotations'] = filtered_annos + selected_annos
        return orig_img, dataset_dict
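

# Usage sketch (not part of the original gist): a self-contained smoke test on
# synthetic data, illustrating the dataset-dict format the augmentator expects
# (a 'file_name' plus 'annotations' whose 'segmentation' is a COCO RLE mask).
# All names and values below are made up for illustration.
if __name__ == '__main__':
    import os
    import tempfile

    def _make_sample(tmpdir, name, shade, y0):
        img = np.full((64, 64, 3), shade, dtype=np.uint8)
        path = os.path.join(tmpdir, name)
        cv2.imwrite(path, img)
        mask = np.zeros((64, 64), dtype=np.uint8)
        mask[y0:y0 + 16, 8:24] = 1  # one square "cell"
        rle = mask_util.encode(np.asfortranarray(mask))
        return {'file_name': path,
                'annotations': [{'category_id': 0, 'segmentation': rle}]}

    with tempfile.TemporaryDirectory() as tmpdir:
        ds = [_make_sample(tmpdir, 'a.png', 50, 8),
              _make_sample(tmpdir, 'b.png', 200, 32)]
        aug = CopyPasteAugmentator(ds, paste_same_class=True,
                                   paste_density=1.0, p=1.0)
        img, out = aug(copy.deepcopy(ds[0]))
        # The pasted cell is appended; original cells are clipped or dropped
        # depending on how much the pasted region occludes them
        print('annotations after copy-paste:', len(out['annotations']))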


import copy

import albumentations as A
import numpy as np
import torch
from PIL import Image

import detectron2.data.transforms as T
from detectron2.data import build_detection_test_loader, build_detection_train_loader
from detectron2.data import detection_utils as utils


class CustomDatasetMapper:
    def __init__(self, cfg):
        # NOTE: `train_ds` (the list of Detectron2 dataset dicts) is assumed
        # to be defined in the enclosing scope, e.g. the training notebook.
        self.copypaste_augmentator = CopyPasteAugmentator(
            train_ds,
            paste_same_class=True,
            paste_density=cfg.CUSTOM_MAPPER.PASTE_DENSITY,
            filter_area_thresh=0.1,
            p=cfg.CUSTOM_MAPPER.COPYPASTE_PROB,
        )
        self.min_size_train = cfg.INPUT.MIN_SIZE_TRAIN
        self.max_size_train = cfg.INPUT.MAX_SIZE_TRAIN
        # Geometric transformations; see the Detectron2 "Data Augmentation"
        # tutorial for detailed usage
        self.augs = T.AugmentationList([
            T.ResizeShortestEdge(
                short_edge_length=self.min_size_train,
                max_size=self.max_size_train,
                sample_style='choice', interp=Image.BICUBIC),
            # T.RandomCrop("relative", (.8, .8)),
            T.RandomFlip(prob=0.5, vertical=True, horizontal=False),
            T.RandomFlip(prob=0.5, vertical=False, horizontal=True),
            # T.RandomRotation(angle=[-5, 5], sample_style='range', expand=False, p=0.2),
        ])  # type: T.Augmentation
        # Non-geometric (photometric) transformations
        self.albu_transform = A.Compose([
            # A.CLAHE(p=0.3),
            # A.Blur(p=0.3),
            # A.MedianBlur(p=0.3),
            # A.MotionBlur(p=0.3),
            # A.RandomBrightnessContrast(p=0.3),
        ])

    # A minimal mapper, similar to the default DatasetMapper
    def __call__(self, dataset_dict):
        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
        # Reads the image from disk and applies copy-paste augmentation;
        # the selected donor annotations are already deep-copied inside
        image, dataset_dict = self.copypaste_augmentator(dataset_dict)
        # Geometric augmentations, applied to image and annotations alike
        auginput = T.AugInput(image)
        transform = self.augs(auginput)
        image = auginput.image
        # Photometric augmentations (image only)
        image = self.albu_transform(image=image)['image']
        image_shape = image.shape[:2]  # h, w
        annos = [
            utils.transform_instance_annotations(annotation, [transform], image_shape)
            for annotation in dataset_dict.pop("annotations")
        ]
        dataset_dict['image'] = torch.as_tensor(image.transpose(2, 0, 1).astype(np.float32))
        dataset_dict['instances'] = utils.annotations_to_instances(
            annos, image_shape, mask_format="bitmask")
        return dataset_dict
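

# Usage sketch (not part of the original gist): wiring the mapper into a
# Detectron2 trainer by overriding `build_train_loader` on DefaultTrainer.
# The CUSTOM_MAPPER.* keys are custom, so they must be registered on the
# config first, e.g.:
#     cfg.CUSTOM_MAPPER = CfgNode()
#     cfg.CUSTOM_MAPPER.PASTE_DENSITY = (0.3, 0.6)
#     cfg.CUSTOM_MAPPER.COPYPASTE_PROB = 0.5
from detectron2.engine import DefaultTrainer


class Trainer(DefaultTrainer):
    @classmethod
    def build_train_loader(cls, cfg):
        # build_detection_train_loader accepts any callable as the mapper
        return build_detection_train_loader(cfg, mapper=CustomDatasetMapper(cfg))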