Skip to content

Instantly share code, notes, and snippets.

@YimianDai
Last active September 15, 2019 06:27
Show Gist options
  • Save YimianDai/f961e3059e16c1810576f3465afcaf5e to your computer and use it in GitHub Desktop.
Save YimianDai/f961e3059e16c1810576f3465afcaf5e to your computer and use it in GitHub Desktop.
YOLO3DefaultTrainTransform

YOLO3DefaultTrainTransform 接受 VOCDetection 的输入, 完成两个功能:

  1. 一个是对 img 做 data augmentation, label 也做相应变换
  2. 另一个是按照 YOLOV3 的要求, 将 Human Labels for BBox 转变成 Model Labels for Anchors (targets)

__init__

        # in case network has reset_ctx to gpu
        self._fake_x = mx.nd.zeros((1, 3, height, width))
        net = copy.deepcopy(net)
        net.collect_params().reset_ctx(None)
  1. self._fake_x 是 (1, 3, H, W) 的 mx.ndarray, 全零, 因为是 fake data 嘛, 这是用于后面产生 anchors 的
        with autograd.train_mode():
            _, self._anchors, self._offsets, self._feat_maps, _, _, _, _ = net(self._fake_x)
  1. self._anchors 是 List of mx.ndarray, 每个元素大小为 (1, 1, 3, 2)
  2. self._offsets 是 List of mx.ndarray, 每个元素大小为 (1, 1, 128, 128, 2)
  3. self._feat_maps 是 List of mx.ndarray, 里面每个元素是 fake_featmap, 大小为 (1, 1, H_i, W_i)
        from ....model_zoo.yolo.yolo_target import YOLOV3PrefetchTargetGenerator
        self._target_generator = YOLOV3PrefetchTargetGenerator(
            num_class=len(net.classes), **kwargs)
  1. 创建 YOLOV3PrefetchTargetGenerator 类的实例, 该实例就是负责将 Label (Label for Human) 转化成 Target (Label for Model)

__call__

    def __call__(self, src, label):
        """Apply transform to training image/label."""
        # random color jittering
        img = experimental.image.random_color_distort(src)
        
        # random expansion with prob 0.5
        if np.random.uniform(0, 1) > 0.5:
            img, expand = timage.random_expand(img, fill=[m * 255 for m in self._mean])
            bbox = tbbox.translate(label, x_offset=expand[0], y_offset=expand[1])
        else:
            img, bbox = img, label   

        # resize with random interpolation
        h, w, _ = img.shape
        interp = np.random.randint(0, 5)
        img = timage.imresize(img, self._width, self._height, interp=interp)
        bbox = tbbox.resize(bbox, (w, h), (self._width, self._height))

        # random horizontal flip
        h, w, _ = img.shape
        img, flips = timage.random_flip(img, px=0.5)
        bbox = tbbox.flip(bbox, (w, h), flip_x=flips[0])            
        
        # to tensor
        img = mx.nd.image.to_tensor(img)
        img = mx.nd.image.normalize(img, mean=self._mean, std=self._std)   
        
        if self._target_generator is None:
            return img, bbox.astype(img.dtype)        
  1. src 是 (H, W, 3) 的 mx.ndarray, 传入的 label 是 (M, 6) 的 np.ndarray, 每一行存的是 [xmin, ymin, xmax, ymax, cls_id, difficult]
  2. 至此, img 是 (3, H, W) 的 mx.ndarray, bbox 是 (M, 6) 的 mx.ndarray
        # generate training target so cpu workers can help reduce the workload on gpu
        gt_bboxes = mx.nd.array(bbox[np.newaxis, :, :4])
        gt_ids = mx.nd.array(bbox[np.newaxis, :, 4:5])
        if self._mixup:
            gt_mixratio = mx.nd.array(bbox[np.newaxis, :, -1:])
        else:
            gt_mixratio = None
  1. gt_boxes 是 (1, M, 4) 的 mxnet.ndarray,是 [xmin, ymin, xmax, ymax] 的 Corner 编码,这些都是范围在 0 - 图像宽或高 之间的整数
  2. gt_ids 是 (1, M, 1) 的 mxnet.ndarray,是不带 Background 的类别标号
        objectness, center_targets, scale_targets, weights, class_targets = self._target_generator(
            self._fake_x, self._feat_maps, self._anchors, self._offsets,
            gt_bboxes, gt_ids, gt_mixratio)
        return (img, objectness[0], center_targets[0], scale_targets[0], weights[0],
                class_targets[0], gt_bboxes[0])  
  1. self._fake_x = mx.nd.zeros((1, 3, height, width))

YOLOV3PrefetchTargetGenerator 类实例的作用是将 Label for Human (gt_bboxes 和 gt_ids) 转化成 Label for Model 即 anchors 的 label (objectness, center_targets, scale_targets, class_targets), 具体流程是:

  1. 根据输入的 feature maps 大小和 anchor 的尺寸配置, 为 feat maps 上的每个点都生成相应的 anchors
  2. 根据 gt_bboxes 和 gt_ids, 计算 IoU, 挑选与给定的 gt bbox 最 match 的 anchor
  3. 根据 gt_bboxes 和 gt_ids 编码好 anchor 的 objness, bbox, cls_id

最后返回的就是 Label for Model 即 anchors 的 label (objectness, center_targets, scale_targets, class_targets) 了, 其中

  1. img 是 (3, H, W) 的 mx.ndarray
  2. objectness[0] 是 (H_3 x W_3 x 3 + H_2 x W_2 x 3 + H_1 x W_1 x 3, 1) 的 mx.ndarray, 在不用 mixup 的情况下, 匹配 anchor 的数值为 1
  3. center_targets[0] 是 (H_3 x W_3 x 3 + H_2 x W_2 x 3 + H_1 x W_1 x 3, 2) 的 mx.ndarray, 里面存的是在原图上 gt box 中心与所属 cell 左上角的归一化距离 (以 cell 长或宽的归一化距离)
  4. scale_targets[0] 是 (H_3 x W_3 x 3 + H_2 x W_2 x 3 + H_1 x W_1 x 3, 2) 的 mx.ndarray, 里面存的是在原图上 gt box 的长或宽相对于所匹配 anchor 长或宽的比例再取 log
  5. weights[0] 是 (H_3 x W_3 x 3 + H_2 x W_2 x 3 + H_1 x W_1 x 3, 2) 的 mx.ndarray
  6. class_targets[0] 是 (H_3 x W_3 x 3 + H_2 x W_2 x 3 + H_1 x W_1 x 3, num_class) 的 mx.ndarray
  7. gt_bboxes[0] 是 (M, 4) 的 mxnet.ndarray,是 [xmin, ymin, xmax, ymax] 的 Corner 编码,这些都是范围在 0 - 图像宽或高 之间的整数
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment