@YimianDai
Last active September 16, 2019 22:48
YOLOV3Loss

Contents

  1. Overview
  2. Code walkthrough

In YOLOV3, a YOLOV3Loss instance is created:

        self._loss = YOLOV3Loss()

In YOLOV3, the YOLOV3Loss instance is invoked:

                return self._loss(*(all_preds + all_targets))

YOLOV3Loss's __init__ and the signature of its hybrid_forward:

    def __init__(self, batch_axis=0, weight=None, **kwargs):
        super(YOLOV3Loss, self).__init__(weight, batch_axis, **kwargs)
        self._sigmoid_ce = gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False)
        self._l1_loss = gluon.loss.L1Loss()

    def hybrid_forward(self, F, objness, box_centers, box_scales, cls_preds,
                       objness_t, center_t, scale_t, weight_t, class_t, class_mask):
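Both loss modules above are fed raw (pre-sigmoid) predictions, since `from_sigmoid=False`. As a hedged reference, the per-element quantity that `SigmoidBinaryCrossEntropyLoss` computes on raw logits can be sketched in NumPy as follows; this is a simplification, and the real Gluon operator additionally handles sample_weight broadcasting and batch-axis averaging:

```python
import numpy as np

def sigmoid_bce(pred, label, sample_weight=1.0):
    # Numerically stable binary cross-entropy on raw logits, the standard form
    # used when from_sigmoid=False:
    #   loss = max(pred, 0) - pred * label + log(1 + exp(-|pred|))
    loss = np.maximum(pred, 0) - pred * label + np.log1p(np.exp(-np.abs(pred)))
    return loss * sample_weight

# A logit of 0 (sigmoid = 0.5) against label 1 gives log(2) ~ 0.6931.
print(sigmoid_bce(np.array([0.0]), np.array([1.0])))
```

Passing `sample_weight=0` for a position zeroes out its loss, which is exactly how the ignore mask below works.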

Inputs:

  1. objness: an mx.ndarray of shape (B, H_1 x W_1 x num_anchors + ... + H_3 x W_3 x num_anchors, 1)
  2. box_centers: an mx.ndarray of shape (B, H_1 x W_1 x num_anchors + ... + H_3 x W_3 x num_anchors, 2)
  3. box_scales: an mx.ndarray of shape (B, H_1 x W_1 x num_anchors + ... + H_3 x W_3 x num_anchors, 2)
  4. cls_preds: an mx.ndarray of shape (B, H_1 x W_1 x num_anchors + ... + H_3 x W_3 x num_anchors, num_class)
  5. objness_t: an mx.ndarray of shape (B, H_1 x W_1 x num_anchors + ... + H_3 x W_3 x num_anchors, 1); 1 marks the best-matching anchor, 0 means the IoU is below the ignore threshold, and -1 marks an anchor that is not the best match but whose IoU exceeds ignore_iou_thresh
  6. center_t: an mx.ndarray of shape (B, H_1 x W_1 x num_anchors + ... + H_3 x W_3 x num_anchors, 2)
  7. scale_t: an mx.ndarray of shape (B, H_1 x W_1 x num_anchors + ... + H_3 x W_3 x num_anchors, 2)
  8. weight_t: an mx.ndarray of shape (B, H_1 x W_1 x num_anchors + ... + H_3 x W_3 x num_anchors, 2)
  9. class_t: an mx.ndarray of shape (B, H_1 x W_1 x num_anchors + ... + H_3 x W_3 x num_anchors, num_class)
  10. class_mask: an mx.ndarray of shape (B, H_1 x W_1 x num_anchors + ... + H_3 x W_3 x num_anchors, num_class)
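As a sanity check on the shared second dimension, assuming the usual 416 x 416 input with strides 8/16/32 and 3 anchors per scale (illustrative values, not read from the code above):

```python
# Illustrative only: 416x416 input, strides 8/16/32, 3 anchors per scale.
num_anchors = 3
feature_sizes = [416 // s for s in (8, 16, 32)]  # H_i = W_i = [52, 26, 13]

# The concatenated prediction dimension:
#   H_1*W_1*3 + H_2*W_2*3 + H_3*W_3*3
total_preds = sum(h * h * num_anchors for h in feature_sizes)
print(total_preds)  # 52*52*3 + 26*26*3 + 13*13*3 = 10647
```

This per-image count is also what `denorm` works out to in the code below.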
The first part of hybrid_forward:

        # compute some normalization count, except batch-size
        denorm = F.cast(
            F.shape_array(objness_t).slice_axis(axis=0, begin=1, end=None).prod(), 'float32')
        weight_t = F.broadcast_mul(weight_t, objness_t)
        hard_objness_t = F.where(objness_t > 0, F.ones_like(objness_t), objness_t)
        new_objness_mask = F.where(objness_t > 0, objness_t, objness_t >= 0)
        obj_loss = F.broadcast_mul(
            self._sigmoid_ce(objness, hard_objness_t, new_objness_mask), denorm)
        center_loss = F.broadcast_mul(self._sigmoid_ce(box_centers, center_t, weight_t), denorm * 2)

  1. denorm is simply H_1 x W_1 x num_anchors + ... + H_3 x W_3 x num_anchors; computing it via shape_array feels more complicated than necessary.
  2. hard_objness_t sets every entry of objness_t greater than 0 to 1 and keeps the -1 and 0 entries; since 1 is the only positive value objness_t contains anyway, this appears to change nothing.
  3. new_objness_mask sets the positions where objness_t is 1 or 0 to 1 and the -1 positions to 0, i.e. it is a mask of non-ignored anchors. It is passed as sample_weight to SigmoidBinaryCrossEntropyLoss, so the loss of ignored anchors is zeroed out and not counted.
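The effect of the two F.where calls can be sketched with NumPy stand-ins on toy values (assumed, for illustration only):

```python
import numpy as np

# Toy objness_t: 1 = best-matching anchor, 0 = IoU below the ignore threshold,
# -1 = not the best match but IoU above ignore_iou_thresh (to be ignored).
objness_t = np.array([1.0, 0.0, -1.0, 1.0])

# F.where(objness_t > 0, ones, objness_t): positives forced to 1, rest kept.
hard_objness_t = np.where(objness_t > 0, np.ones_like(objness_t), objness_t)

# F.where(objness_t > 0, objness_t, objness_t >= 0): 1 for positives and true
# negatives (label 0), 0 for ignored (-1) anchors.
new_objness_mask = np.where(objness_t > 0, objness_t,
                            (objness_t >= 0).astype(objness_t.dtype))

print(hard_objness_t)    # positives stay 1, 0 and -1 unchanged
print(new_objness_mask)  # 0 exactly at the -1 (ignored) positions
```

With this mask as sample_weight, the objectness BCE sees targets of 1 or 0 everywhere it is evaluated, and contributes nothing at ignored anchors.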