@ak64th
Created September 29, 2021 04:03
Use imgaug with detectron2
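
The script below hooks an imgaug augmentation pipeline into detectron2 through a custom dataset mapper: the mapper reads each image, applies a deterministic snapshot of the augmenter to the image together with its bounding boxes and polygon segmentations, and repacks the result into the dict format detectron2's data loaders expect. A DefaultTrainer subclass then installs this mapper in the train and test loaders.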
import argparse
import copy
import os
import pathlib
from datetime import datetime
import imgaug.augmenters as iaa
import imgaug.random as ia_random
import numpy as np
import torch
from detectron2.config import get_cfg
from detectron2.data import build_detection_train_loader, build_detection_test_loader
from detectron2.data.detection_utils import (
    read_image,
    check_image_size,
    annotations_to_instances,
    filter_empty_instances,
)
from detectron2.engine import DefaultTrainer, launch
from detectron2.evaluation import COCOEvaluator
from detectron2.structures import BoxMode
from detectron2.utils.logger import setup_logger
from imgaug.augmentables.bbs import BoundingBox, BoundingBoxesOnImage
from imgaug.augmentables.polys import Polygon, PolygonsOnImage


class Mapper:
    def __init__(
        self,
        augmenter: iaa.Augmenter,
        is_train: bool = True,
        image_format='BGR',
    ):
        self.augmenter = augmenter
        self.is_train = is_train
        self.image_format = image_format

    def __call__(self, dataset_dict):
        dataset_dict = copy.deepcopy(dataset_dict)
        dataset_dict.pop('sem_seg_file_name', None)  # not needed here
        image = read_image(dataset_dict['file_name'], format=self.image_format)
        check_image_size(dataset_dict, image)
        # FIXME: resize the image?
        # snapshot the augmenter so the image and its annotations
        # receive exactly the same random transform
        deterministic = self.augmenter.to_deterministic()
        augmented_image = deterministic.augment_image(image)
        dataset_dict['image'] = torch.as_tensor(
            np.ascontiguousarray(augmented_image.transpose(2, 0, 1))
        )  # HWC -> CHW; use a torch.Tensor for efficiency
        if not self.is_train or 'annotations' not in dataset_dict:
            dataset_dict.pop('annotations', None)
            return dataset_dict
        for anno in dataset_dict['annotations']:
            anno.pop('keypoints', None)  # not needed here
        # transform bounding boxes and segmentations
        annos = [
            _transform_annotation(obj, image.shape, deterministic)
            for obj in dataset_dict.pop('annotations')
            if obj.get('iscrowd', 0) == 0
        ]
        # build the Instances structure; image_size must be (h, w)
        instances = annotations_to_instances(
            annos, augmented_image.shape[:2], mask_format='polygon'
        )
        dataset_dict['instances'] = filter_empty_instances(instances)
        return dataset_dict
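

# A quick way to sanity-check the mapper before wiring it into a loader
# (a hedged sketch: the file name, sizes, and annotation below are made-up
# placeholders, not part of this gist):
#
#   mapper = Mapper(iaa.Fliplr(1.0))
#   record = {
#       'file_name': '/path/to/example.jpg', 'image_id': 0,
#       'height': 480, 'width': 640,
#       'annotations': [{
#           'bbox': [10, 10, 100, 100], 'bbox_mode': BoxMode.XYWH_ABS,
#           'category_id': 0,
#           'segmentation': [[10.0, 10.0, 110.0, 10.0, 110.0, 110.0, 10.0, 110.0]],
#       }],
#   }
#   out = mapper(record)  # out['image'] is a CxHxW tensor;
#                         # out['instances'] holds the flipped boxes/polygons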


def _transform_annotation(annotation, image_shape, augmentation: iaa.Augmenter):
    assert augmentation.deterministic, 'Augmenter instance is not deterministic.'
    # transform the bounding box
    bbox = BoxMode.convert(annotation['bbox'], annotation['bbox_mode'], BoxMode.XYXY_ABS)
    _bbox = augmentation.augment_bounding_boxes(
        BoundingBoxesOnImage([BoundingBox(*bbox)], shape=image_shape)
    ).remove_out_of_image().clip_out_of_image().bounding_boxes[0]
    annotation['bbox'] = [_bbox.x1, _bbox.y1, _bbox.x2, _bbox.y2]
    annotation['bbox_mode'] = BoxMode.XYXY_ABS
    if 'segmentation' not in annotation:
        return annotation
    # transform the segmentation
    segm = annotation['segmentation']
    # for simplicity, only polygon segmentations are handled for now
    assert isinstance(segm, list), 'Unsupported segmentation format'
    polygons = [Polygon(np.asarray(p).reshape(-1, 2)) for p in segm]
    _polygons = augmentation.augment_polygons(
        PolygonsOnImage(polygons, image_shape)
    ).remove_out_of_image().polygons
    annotation['segmentation'] = [p.coords.reshape(-1) for p in _polygons]
    return annotation
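
# For reference, BoxMode.convert maps a box into the corner form imgaug's
# BoundingBox expects; e.g. a COCO-style XYWH_ABS box [10, 10, 100, 100]
# becomes the XYXY_ABS box [10, 10, 110, 110] (x1, y1, x2, y2).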


class AugTrainer(DefaultTrainer):
    @classmethod
    def build_evaluator(cls, cfg, dataset_name):
        output_folder = os.path.join(cfg.OUTPUT_DIR, 'inference')
        return COCOEvaluator(dataset_name, output_dir=output_folder)

    @classmethod
    def build_train_loader(cls, cfg):
        augmentation = iaa.Sequential([
            iaa.Fliplr(0.2),
            iaa.Sometimes(
                0.5,
                iaa.Sequential([
                    iaa.Sometimes(
                        0.5,
                        iaa.AddToHueAndSaturation(
                            (-20, 20),
                            per_channel=True,
                        ),
                        iaa.Grayscale(alpha=(0, 0.2)),
                    ),
                    iaa.Sometimes(
                        0.5,
                        iaa.WithBrightnessChannels(iaa.Add((-50, 50))),
                        iaa.EdgeDetect(alpha=(0, 0.3)),
                    ),
                ]),
            ),
        ])
        mapper = Mapper(augmentation, is_train=True)
        return build_detection_train_loader(cfg, mapper=mapper)
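
    # Note on the pipeline above: iaa.Sometimes(p, then_list, else_list)
    # applies then_list with probability p and else_list otherwise, so each
    # inner Sometimes picks exactly one of its two effects per image, and
    # the outer Sometimes skips the whole color/brightness block half of
    # the time. Fliplr(0.2) flips boxes and polygons along with the image.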

    @classmethod
    def build_test_loader(cls, cfg, dataset_name):
        mapper = Mapper(iaa.Noop(), is_train=False)
        return build_detection_test_loader(cfg, dataset_name, mapper=mapper)


def train(config_file, output_dir, image_dir, extra_aug=True):
    # NOTE: image_dir and extra_aug are currently unused; datasets are
    # expected to be registered under the names given in the config file
    cfg = get_cfg()
    cfg.merge_from_file(config_file)
    cfg.OUTPUT_DIR = output_dir
    cfg.freeze()
    ia_random.seed(42)  # seed imgaug before the trainer builds its data loader
    trainer = AugTrainer(cfg)
    trainer.resume_or_load(resume=False)  # load cfg.MODEL.WEIGHTS if set
    return trainer.train()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train and evaluate based on a config file.')
    parser.add_argument('config', help='Specify config file', metavar='CONFIG_FILE')
    parser.add_argument('images', help='Image root directory')
    parser.add_argument('-o', '--output', help='Output directory', default=None)
    parser.add_argument('--gpu-per-machine', help='Number of GPUs per machine', type=int, default=1)
    # example: run the default mask rcnn training on coco datasets with images stored in /mnt/images
    # python train_with_imgaug.py ./configs/COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml /mnt/images
    args = parser.parse_args()
    config = str(pathlib.Path(args.config).resolve())
    output = pathlib.Path(args.output or './output/' + datetime.now().strftime('%Y%m%dT%H%M'))
    output.mkdir(parents=True, exist_ok=True)
    output = str(output.resolve())
    images = str(pathlib.Path(args.images).resolve())
    setup_logger()
    launch(train, args.gpu_per_machine, dist_url='auto', args=(config, output, images))