Last active
January 14, 2022 11:41
-
-
Save hav4ik/2d80b63dcab69651ad7bf1495053e35c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import cv2 | |
import copy | |
import pycocotools.mask as mask_util | |
import numpy as np | |
class CopyPasteAugmentator:
    """Copy-paste cells from another image in the dataset.

    On each call (with probability ``p``) a donor record is drawn from
    ``d2_dataset``. A random fraction of the donor's instance masks is pasted
    onto the target image. Target instances whose remaining visible area
    becomes too small are dropped.

    Args:
        d2_dataset: Sequence of Detectron2-style dataset dicts. Each record is
            read for ``"file_name"`` and ``"annotations"``. Each annotation is
            read for ``"category_id"`` and an RLE ``"segmentation"`` accepted
            by ``pycocotools.mask.decode``.
        paste_same_class: If True, only draw donors whose first annotation has
            the same ``category_id`` as the target's first annotation.
        paste_density: Fraction of the donor's instances to paste. Either a
            number, or a ``[low, high]`` list/tuple from which a value is drawn
            uniformly on every call.
        filter_area_thresh: A target instance is kept only if its still-visible
            pixel count exceeds ``round(filter_area_thresh * original_area)``.
        p: Probability of applying the augmentation on each call.
        num_classes: Number of class ids (``0 .. num_classes - 1``) indexed for
            same-class donor sampling. Defaults to 3, the previously
            hard-coded value.
    """

    def __init__(self, d2_dataset,
                 paste_same_class=True,
                 paste_density=(0.3, 0.6),
                 filter_area_thresh=0.1,
                 p=1.0,
                 num_classes=3):
        self.data = d2_dataset
        self.n_samples = len(d2_dataset)
        self.paste_same_class = paste_same_class
        if paste_same_class:
            # cls_indices[c] lists the dataset indices whose first annotation
            # has category c. Records without annotations cannot be
            # classified this way, so they are skipped instead of raising
            # IndexError.
            self.cls_indices = [[] for _ in range(num_classes)]
            for i, item in enumerate(d2_dataset):
                annos = item.get('annotations') or []
                if not annos:
                    continue
                cls_id = annos[0]['category_id']
                if 0 <= cls_id < num_classes:
                    self.cls_indices[cls_id].append(i)
        self.filter_area_thresh = filter_area_thresh
        self.paste_density = paste_density
        self.p = p

    def __call__(self, dataset_dict):
        """Read the target image and, with probability ``p``, paste donor cells.

        Args:
            dataset_dict: Dataset dict with ``"file_name"`` and (optionally)
                ``"annotations"``. It is not modified.

        Returns:
            ``(image, dataset_dict)``. ``image`` is the array returned by
            ``cv2.imread``, with donor pixels copied in if the augmentation ran.
            When pasting happens, the returned dict is a shallow copy whose
            ``"annotations"`` are the surviving target annotations followed by
            the pasted donor annotations. Otherwise the input dict is
            returned as-is.

        Raises:
            FileNotFoundError: If the target image cannot be read.
        """
        orig_img = cv2.imread(dataset_dict["file_name"])
        if orig_img is None:
            raise FileNotFoundError(
                f"cv2.imread could not read {dataset_dict['file_name']!r}")
        # LIVECell images are returned without copy-paste.
        if 'LIVECell_dataset_2021' in dataset_dict["file_name"]:
            return orig_img, dataset_dict
        if np.random.uniform() >= self.p:
            return orig_img, dataset_dict

        donor = self._pick_donor(dataset_dict)
        if donor is None or not donor.get('annotations'):
            return orig_img, dataset_dict
        donor_img = cv2.imread(donor['file_name'])
        if donor_img is None or donor_img.shape[:2] != orig_img.shape[:2]:
            # Donor and target masks are combined pixel-for-pixel, so both
            # images must share height and width. Augmentation is optional,
            # so skip it rather than fail the sample.
            return orig_img, dataset_dict

        donor_annos = donor['annotations']
        n_donor = len(donor_annos)
        density = self._sample_density()
        # Clamp to [0, n_donor]: choice(replace=False) cannot draw more than
        # n_donor items.
        n_paste = min(n_donor, max(0, int(round(density * n_donor))))
        if n_paste == 0:
            # E.g. a donor with a single cell and density < 0.5.
            return orig_img, dataset_dict
        picked = np.random.choice(n_donor, size=n_paste, replace=False)
        # Deep-copy only the selected annotations, not the whole donor record.
        pasted_annos = [copy.deepcopy(donor_annos[i]) for i in picked]

        paste_mask = self._union_mask(pasted_annos)
        not_pasted = ~paste_mask

        kept_annos = []
        for anno in dataset_dict.get('annotations', []):
            mask = mask_util.decode(anno['segmentation']).astype(bool)
            visible = mask & not_pasted
            # Keep the instance only if enough of it is still visible.
            if round(self.filter_area_thresh * mask.sum()) < visible.sum():
                anno = dict(anno)  # shallow copy: leave the caller's dict intact
                # pycocotools documents encode() input as uint8, column-major.
                anno['segmentation'] = mask_util.encode(
                    np.asfortranarray(visible, dtype=np.uint8))
                kept_annos.append(anno)

        orig_img[paste_mask] = donor_img[paste_mask]
        result = dict(dataset_dict)
        result['annotations'] = kept_annos + pasted_annos
        return orig_img, result

    def _pick_donor(self, dataset_dict):
        """Return a random donor record, or None if no suitable donor exists."""
        if not self.paste_same_class:
            if self.n_samples == 0:
                return None
            return self.data[np.random.randint(0, self.n_samples)]
        annos = dataset_dict.get('annotations') or []
        if not annos:
            return None
        cls_id = annos[0]['category_id']
        if not 0 <= cls_id < len(self.cls_indices):
            return None
        pool = self.cls_indices[cls_id]
        if not pool:
            return None
        return self.data[pool[np.random.randint(0, len(pool))]]

    def _sample_density(self):
        """Return this call's paste fraction drawn from ``self.paste_density``."""
        if isinstance(self.paste_density, (list, tuple)):
            return np.random.uniform(self.paste_density[0], self.paste_density[1])
        return self.paste_density

    @staticmethod
    def _union_mask(annos):
        """Return the boolean union of the decoded RLE masks of non-empty ``annos``."""
        # Plain ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
        union = mask_util.decode(annos[0]['segmentation']).astype(bool)
        for anno in annos[1:]:
            union |= mask_util.decode(anno['segmentation']).astype(bool)
        return union
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import albumentations as A | |
import copy | |
import numpy as np | |
from PIL import Image | |
import detectron2.data.transforms as T | |
import torch | |
from detectron2.data import detection_utils as utils | |
from detectron2.data import build_detection_test_loader, build_detection_train_loader | |
from detectron2.data import detection_utils as utils | |
class CustomDatasetMapper:
    """Map a Detectron2 dataset dict to a training sample, with copy-paste.

    Per call, the pipeline runs in this order:
    1. Copy-paste of cells from a donor image (``CopyPasteAugmentator``).
    2. Random shortest-edge resize.
    3. Random vertical and horizontal flips.
    4. An albumentations pixel-level pipeline, currently empty.

    Annotations are transformed with the same geometric transform and
    converted to ``Instances`` with bitmask masks.
    """

    def __init__(self, cfg, dataset=None):
        """Build the augmentation pipeline from ``cfg``.

        Args:
            cfg: Detectron2 config node. Reads ``CUSTOM_MAPPER.PASTE_DENSITY``,
                ``CUSTOM_MAPPER.COPYPASTE_PROB``, ``INPUT.MIN_SIZE_TRAIN`` and
                ``INPUT.MAX_SIZE_TRAIN``.
            dataset: Dataset dicts to draw copy-paste donors from. When None,
                falls back to a global named ``train_ds`` (the original
                implicit behaviour), which must then be defined in this
                module's namespace.
        """
        if dataset is None:
            # Backward-compatible fallback: ``train_ds`` is not defined in the
            # visible source. Presumably it is created by the surrounding
            # script/notebook.
            dataset = train_ds
        self.copypaste_augmentator = CopyPasteAugmentator(
            dataset,
            paste_same_class=True,
            paste_density=cfg.CUSTOM_MAPPER.PASTE_DENSITY,
            filter_area_thresh=0.1,
            p=cfg.CUSTOM_MAPPER.COPYPASTE_PROB,
        )
        self.min_size_train = cfg.INPUT.MIN_SIZE_TRAIN
        self.max_size_train = cfg.INPUT.MAX_SIZE_TRAIN
        # Geometric augmentations. Their combined transform is re-applied to
        # the annotations in __call__ (see Detectron2 "Data Augmentation").
        self.augs = T.AugmentationList([
            T.ResizeShortestEdge(
                short_edge_length=self.min_size_train,
                max_size=self.max_size_train,
                sample_style='choice', interp=Image.BICUBIC),
            # T.RandomCrop("relative", (.8, .8)),
            T.RandomFlip(prob=0.5, vertical=True, horizontal=False),
            T.RandomFlip(prob=0.5, vertical=False, horizontal=True),
            # T.RandomRotation(angle=[-5, 5], sample_style='range', expand=False, p=0.2),
        ])  # type: T.Augmentation
        # Non-geometric (pixel-level) transformations only. Annotations are
        # NOT updated for anything in this pipeline, so adding a geometric
        # albumentations op here would misalign the masks.
        self.albu_transform = A.Compose([
            # A.CLAHE(p=0.3),
            # A.Blur(p=0.3),
            # A.MedianBlur(p=0.3),
            # A.MotionBlur(p=0.3),
            # A.RandomBrightnessContrast(p=0.3),
        ])

    def __call__(self, dataset_dict):
        """Produce one training sample.

        Args:
            dataset_dict: Detectron2 dataset dict with ``"file_name"`` and
                ``"annotations"``. It is not modified.

        Returns:
            A deep copy of the dict with ``"annotations"`` removed. Two keys
            are added: ``"image"``, a float32 tensor in C x H x W order, and
            ``"instances"``, a Detectron2 ``Instances`` with bitmask masks.
        """
        # Work on a private copy: the ``pop`` below (and any in-place edits by
        # the augmentator) must not touch the dataset's own record.
        dataset_dict = copy.deepcopy(dataset_dict)
        # Reads the image from disk (cv2.imread) and possibly pastes donor
        # cells in. It returns the image and the matching annotations.
        image, dataset_dict = self.copypaste_augmentator(dataset_dict)

        auginput = T.AugInput(image)
        transform = self.augs(auginput)
        image = self.albu_transform(image=auginput.image)['image']
        image_shape = image.shape[:2]  # (h, w)

        annos = [
            utils.transform_instance_annotations(annotation, [transform], image_shape)
            for annotation in dataset_dict.pop("annotations")
        ]
        # HWC -> CHW, float32.
        dataset_dict['image'] = torch.as_tensor(
            image.transpose(2, 0, 1).astype(np.float32))
        dataset_dict['instances'] = utils.annotations_to_instances(
            annos, image_shape, mask_format="bitmask")
        return dataset_dict
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment