Skip to content

Instantly share code, notes, and snippets.

@YimianDai
Last active September 16, 2019 03:26
Show Gist options
  • Save YimianDai/9805ad13649f7793fe4d199f03a4a018 to your computer and use it in GitHub Desktop.
Save YimianDai/9805ad13649f7793fe4d199f03a4a018 to your computer and use it in GitHub Desktop.
train_yolo3.py
  1. parse_args
  2. get_dataset
  3. get_dataloader
  4. save_params
  5. validate
  6. train
  7. __main__

1. parse_args

2. get_dataset

3. get_dataloader

def get_dataloader(net, train_dataset, val_dataset, data_shape, batch_size, num_workers, args):
    """Get dataloader."""
    width, height = data_shape, data_shape
    batchify_fn = Tuple(*([Stack() for _ in range(6)] + [Pad(axis=0, pad_val=-1) for _ in range(1)]))  # stack image, all targets generated
  1. batchify_fn 其实就是 Tuple(Stack(), Stack(), Stack(), Stack(), Stack(), Stack(), Pad()), 这么多是因为经过 YOLO3DefaultTrainTransform 的 Dataset 返回的是 (img, objectness[0], center_targets[0], scale_targets[0], weights[0], class_targets[0], gt_bboxes[0]) 也是 7 个元素
  2. Pad() 函数会做 pad 和 stack 两件事, 所以会做 stack 的
    if args.no_random_shape:
        train_loader = gluon.data.DataLoader(
            train_dataset.transform(YOLO3DefaultTrainTransform(width, height, net, mixup=args.mixup)),
            batch_size, True, batchify_fn=batchify_fn, last_batch='rollover', num_workers=num_workers)
    else:
        transform_fns = [YOLO3DefaultTrainTransform(x * 32, x * 32, net, mixup=args.mixup) for x in range(10, 20)]
        train_loader = RandomTransformDataLoader(
            transform_fns, train_dataset, batch_size=batch_size, interval=10, last_batch='rollover',
            shuffle=True, batchify_fn=batchify_fn, num_workers=num_workers)

在 args.no_random_shape 为 True 的情况下

  1. YOLO3DefaultTrainTransform 代码解读 中可知, train_dataset.transform(YOLO3DefaultTrainTransform(width, height, net, mixup=args.mixup)) 返回的东西分别如下:
    1. img 是 (3, H, W) 的 mx.ndarray
    2. objectness[0] 是 (H_3 x W_3 x 3 + H_2 x W_2 x 3 + H_1 x W_1 x 3, 1) 的 mx.ndarray, 在不用 mixup 的情况下, 匹配 anchor 的数值为 1
    3. center_targets[0] 是 (H_3 x W_3 x 3 + H_2 x W_2 x 3 + H_1 x W_1 x 3, 2) 的 mx.ndarray, 里面存的是在原图上 gt box 中心与所属 cell 左上角的归一化距离 (以 cell 长或宽的归一化距离)
    4. scale_targets[0] 是 (H_3 x W_3 x 3 + H_2 x W_2 x 3 + H_1 x W_1 x 3, 2) 的 mx.ndarray, 里面存的是在原图上 gt box 的长或宽相对于所匹配 anchor 长或宽的比例再取 log
    5. weights[0] 是 (H_3 x W_3 x 3 + H_2 x W_2 x 3 + H_1 x W_1 x 3, 2) 的 mx.ndarray
    6. class_targets[0] 是 (H_3 x W_3 x 3 + H_2 x W_2 x 3 + H_1 x W_1 x 3, num_class) 的 mx.ndarray
    7. gt_bboxes[0] 是 (M, 4) 的 mxnet.ndarray,是 [xmin, ymin, xmax, ymax] 的 Corner 编码,这些都是范围在 0 - 图像宽或高 之间的整数
  2. 因此, 经过 batchify_fn 的 DataLoader 的会返回的是
    1. imgs 是 (B, 3, H, W) 的 mx.ndarray
    2. objectness 是 (B, H_3 x W_3 x 3 + H_2 x W_2 x 3 + H_1 x W_1 x 3, 1) 的 mx.ndarray, 在不用 mixup 的情况下, 匹配 anchor 的数值为 1
    3. center_targets 是 (B, H_3 x W_3 x 3 + H_2 x W_2 x 3 + H_1 x W_1 x 3, 2) 的 mx.ndarray, 里面存的是在原图上 gt box 中心与所属 cell 左上角的归一化距离 (以 cell 长或宽的归一化距离)
    4. scale_targets 是 (B, H_3 x W_3 x 3 + H_2 x W_2 x 3 + H_1 x W_1 x 3, 2) 的 mx.ndarray, 里面存的是在原图上 gt box 的长或宽相对于所匹配 anchor 长或宽的比例再取 log
    5. weights 是 (B, H_3 x W_3 x 3 + H_2 x W_2 x 3 + H_1 x W_1 x 3, 2) 的 mx.ndarray
    6. class_targets 是 (B, H_3 x W_3 x 3 + H_2 x W_2 x 3 + H_1 x W_1 x 3, num_class) 的 mx.ndarray
    7. gt_bboxes 是 (B, M_max, 4) 的 mxnet.ndarray,是 [xmin, ymin, xmax, ymax] 的 Corner 编码,这些都是范围在 0 - 图像宽或高 之间的整数
    val_batchify_fn = Tuple(Stack(), Pad(pad_val=-1))
    val_loader = gluon.data.DataLoader(
        val_dataset.transform(YOLO3DefaultValTransform(width, height)),
        batch_size, False, batchify_fn=val_batchify_fn, last_batch='keep', num_workers=num_workers)
    return train_loader, val_loader
  1. val_dataset.transform(YOLO3DefaultValTransform(width, height)) 返回的只是
    1. img 是 (3, H, W) 的 mx.ndarray
    2. bbox 是 (M, 6) 的 np.ndarray
  2. 经过 batchify_fn 的 DataLoader 的会返回的是
    1. imgs 是 (B, 3, H, W) 的 mx.ndarray
    2. bboxes 是 (B, M_max, 6) 的 mxnet.ndarray

4. save_params

5. validate

6. train

        for i, batch in enumerate(train_data):
            batch_size = batch[0].shape[0]
            data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
            # objectness, center_targets, scale_targets, weights, class_targets
            fixed_targets = [gluon.utils.split_and_load(batch[it], ctx_list=ctx, batch_axis=0) for it in range(1, 6)]
            gt_boxes = gluon.utils.split_and_load(batch[6], ctx_list=ctx, batch_axis=0)

假设我有两个 GPU:

  1. batch[0] 是 imgs, (B, 3, H, W) 的 mx.ndarray, data 是 List of mx.ndarray, 每个元素为 (B//2, 3, H, W) 的 mx.ndarray
  2. fixed_targets 是 List of List of mx.ndarray, 里面的每个 List of mx.ndarray 都是 objectness, center_targets, scale_targets, weights, class_targets, 个数分别是 B//2
  3. batch[6] 是 gt_bboxes, (B, M_max, 4) 的 mxnet.ndarray, gt_boxes 是 List of mx.ndarray, 每个元素为 (B//2, M_max, 4) 的 mx.ndarray
            sum_losses = []
            obj_losses = []
            center_losses = []
            scale_losses = []
            cls_losses = []
            with autograd.record():
                for ix, x in enumerate(data):
                    obj_loss, center_loss, scale_loss, cls_loss = net(x, gt_boxes[ix], *[ft[ix] for ft in fixed_targets])
                    sum_losses.append(obj_loss + center_loss + scale_loss + cls_loss)
                    obj_losses.append(obj_loss)
                    center_losses.append(center_loss)
                    scale_losses.append(scale_loss)
                    cls_losses.append(cls_loss)
                autograd.backward(sum_losses)
            trainer.step(batch_size)
  1. 对于 for ix, x in enumerate(data): 这个循环, 因为我只有两个 GPU, 所以 ix 只会是 0 和 1, 因为是 Lazy 机制, 所以并行是在这里实现的, 循环分别会被放在两个 GPU 上
  2. x 是 YOLOV3 的 hybrid_forward 中的 x, (B // 2, 3, H, W) 的 mx.ndarray
  3. gt_boxes[ix] 是 YOLOV3 的 hybrid_forward 中的 gt_boxes, (B // 2, M_max, 4) 的 mxnet.ndarray,是 [xmin, ymin, xmax, ymax] 的 Corner 编码
  4. fixed_targets[0][ix] 是 YOLOV3 的 hybrid_forward 中的 obj_t, 是 (B // 2, H_3 x W_3 x 3 + H_2 x W_2 x 3 + H_1 x W_1 x 3, 1) 的 mx.ndarray
  5. fixed_targets[1][ix] 是 YOLOV3 的 hybrid_forward 中的 centers_t, 是 (B // 2, H_3 x W_3 x 3 + H_2 x W_2 x 3 + H_1 x W_1 x 3, 2) 的 mx.ndarray
  6. fixed_targets[2][ix] 是 YOLOV3 的 hybrid_forward 中的 scales_t, 是 (B // 2, H_3 x W_3 x 3 + H_2 x W_2 x 3 + H_1 x W_1 x 3, 2) 的 mx.ndarray
  7. fixed_targets[3][ix] 是 YOLOV3 的 hybrid_forward 中的 weights_t, 是 (B // 2, H_3 x W_3 x 3 + H_2 x W_2 x 3 + H_1 x W_1 x 3, 2) 的 mx.ndarray
  8. fixed_targets[4][ix] 是 YOLOV3 的 hybrid_forward 中的 clas_t, 是 (B // 2, H_3 x W_3 x 3 + H_2 x W_2 x 3 + H_1 x W_1 x 3, num_class) 的 mx.ndarray

7. __main__

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment