- 概述
- 代码解读
- 2.1
__init__
- 2.2
hybrid_forward
- 2.1
在 YOLOV3 中 YOLOV3Loss 实例被创建:
self._loss = YOLOV3Loss()
在 YOLOV3 中 YOLOV3Loss 实例被调用:
return self._loss(*(all_preds + all_targets))
def __init__(self, batch_axis=0, weight=None, **kwargs):
super(YOLOV3Loss, self).__init__(weight, batch_axis, **kwargs)
self._sigmoid_ce = gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False)
self._l1_loss = gluon.loss.L1Loss()
def hybrid_forward(self, F, objness, box_centers, box_scales, cls_preds,
objness_t, center_t, scale_t, weight_t, class_t, class_mask):
输入:
- objness 形状为 (B, H_1 x W_1 x num_anchors + ... + H_3 x W_3 x num_anchors, 1) 的 mx.ndarray,
- box_centers 形状为 (B, H_1 x W_1 x num_anchors + ... + H_3 x W_3 x num_anchors, 2) 的 mx.ndarray
- box_scales 形状为 (B, H_1 x W_1 x num_anchors + ... + H_3 x W_3 x num_anchors, 2) 的 mx.ndarray
- cls_preds 形状为 (B, H_1 x W_1 x num_anchors + ... + H_3 x W_3 x num_anchors, num_class) 的 mx.ndarray
- objness_t 形状为 (B, H_1 x W_1 x num_anchors + ... + H_3 x W_3 x num_anchors, 1) 的 mx.ndarray, 1 表示是最匹配的 anchor, 为 0 表示 iou 数值小于 ignore thresh, -1 表示并非 iou 最高但是大于 ignore_iou_thresh 的 anchor
- center_t 形状为 (B, H_1 x W_1 x num_anchors + ... + H_3 x W_3 x num_anchors, 2) 的 mx.ndarray
- scale_t 形状为 (B, H_1 x W_1 x num_anchors + ... + H_3 x W_3 x num_anchors, 2) 的 mx.ndarray
- weight_t 形状为 (B, H_1 x W_1 x num_anchors + ... + H_3 x W_3 x num_anchors, 2) 的 mx.ndarray
- class_t 形状为 (B, H_1 x W_1 x num_anchors + ... + H_3 x W_3 x num_anchors, num_class) 的 mx.ndarray
- class_mask 形状为 (B, H_1 x W_1 x num_anchors + ... + H_3 x W_3 x num_anchors, num_class) 的 mx.ndarray
# compute some normalization count, except batch-size
denorm = F.cast(
F.shape_array(objness_t).slice_axis(axis=0, begin=1, end=None).prod(), 'float32')
weight_t = F.broadcast_mul(weight_t, objness_t)
hard_objness_t = F.where(objness_t > 0, F.ones_like(objness_t), objness_t)
new_objness_mask = F.where(objness_t > 0, objness_t, objness_t >= 0)
obj_loss = F.broadcast_mul(
self._sigmoid_ce(objness, hard_objness_t, new_objness_mask), denorm)
center_loss = F.broadcast_mul(self._sigmoid_ce(box_centers, center_t, weight_t), denorm * 2)
- denorm 就是 H_1 x W_1 x num_anchors + ... + H_3 x W_3 x num_anchors, 感觉不需要那么复杂
- hard_objness_t 是 objness_t 中 大于 0 的 objness 会被设成 1, 其余 -1 和 0 会被保留, 因为 objness_t 中本来也只有 1 大于 0, 所以感觉没变
- new_objness_mask 是将 objness_t 中为 1 或 0 的对应位置都为 1, -1 的为 0, 就是一个 非 ignore anchor 的 mask, 作为 SigmoidBinaryCrossEntropyLoss 中的 sample_weight, 也就是 ignore anchor 的损失不计入, 为 0.