Skip to content

Instantly share code, notes, and snippets.

@YimianDai
Last active May 22, 2024 03:21
Show Gist options
  • Save YimianDai/b4dfb38c0f36dd5553dcaaa9799d4ac7 to your computer and use it in GitHub Desktop.
Save YimianDai/b4dfb38c0f36dd5553dcaaa9799d4ac7 to your computer and use it in GitHub Desktop.
MixupDetection

首先需要明确的是 MixupDetection 这个类是在 VOCDetection 类之上的一个 wrapper,因为 mixup 这个技巧/方法可以简单的认为是一个 Data Augmentation 的手段。传统的 Data Augmentation 手段:裁剪、翻转 / 旋转、尺度变化,mixup 的做法在于将两幅图像按照随机权重相加。

MixupDetection 的实现考虑了两幅图像的大小可能是不同的,因此 mixup 后的图像为最大的图像,因此代码中 mixup 图像由下得到

        height = max(img1.shape[0], img2.shape[0])
        width = max(img1.shape[1], img2.shape[1])
        mix_img = mx.nd.zeros(shape=(height, width, 3), dtype='float32')
        mix_img[:img1.shape[0], :img1.shape[1], :] = img1.astype('float32') * lambd
        mix_img[:img2.shape[0], :img2.shape[1], :] += img2.astype('float32') * (1. - lambd)

在实现了图像的混合后,还要对 label 做相应修改。这是个 Detection Dataset 的 wrapper,Detection Dataset 比如 VOCDetection,返回的 label 的大小是 M x 6,M 是其中的 Object 数目。这里的 lambd 是需要添加到 label 中去的,因此就放到每个 object label array 的最后面。对于图片来说,M x 6 的 label 矩阵也就变成了 M x 7,在 label 后面添加一列的代码如下。最后 mixup 的图像对应着原来的两幅图像的 label,所以要把这两个 label vstack 起来,具体用途在 Loss 那边,见 MixSoftmaxCrossEntropyLoss

        y1 = np.hstack((label1, np.full((label1.shape[0], 1), lambd)))
        y2 = np.hstack((label2, np.full((label2.shape[0], 1), 1. - lambd)))
		mix_label = np.vstack((y1, y2))

Loss 那边的具体做法是 网络对 mixup image 的输出 pred,分别根据 label1label2 计算 cross entropy loss,然后把 loss 按照 lambd 相加。论文作者张宏毅在知乎回答中提到,这种不混合 label,单独计算 loss 然后相加作为最终 loss 的做法,由于 cross-entropy loss 的性质,这种做法和把 label 线性加权是等价的。

最后思考下,为什么这种对 raw input 做 weighted sum 会有效?其实这里面有两个层次的问题:

  1. 为什么 Data Augmentation 会有效?
  2. 为什么对 raw input 做 weighted sum 会有效?

将 Data Augmentation 理解为控制模型复杂度

这段文字参考自论文作者张宏毅在知乎的回答

  1. 首先,Data Augmentation 虽然可以理解成是在增加 training data,但是 Data Augmentation 增加出来的 training data 并不是真正符合机器学习假设的在真实的 data distribution 上 i.i.d. 抽样出来的。在机器学习理论中,假设 training data 都是从真实的 data distribution 上i.i.d. 抽样出来的,而事实上我们 Data Augmentation 的手段都是我们拍脑袋想的,并不是真实的 data distribution。因此,不能简单地将 Data Augmentation 理解为增加 training data。
  2. 那么为什么 Data Augmentation 会很有用呢?流行的解释是 “增强模型对某种变换的 invariance”。这句话反过来说,就是机器学习里经常提到的 “减少模型估计的 variance”,也就是控制了模型的复杂度(是不是很像 正则化?)。L2 正则化、dropout 等等也都是在控制模型复杂度,只不过它们没有考虑数据本身的分布,而 data augmentation 属于更加机智的控制模型复杂度的方法
  3. 为什么降低模型估计的 variance 会提高 generalization?这其实是因为本身真实分布的 distribution 的复杂度其实是低于 NN 的复杂度的,让 NN 学到一个复杂度更低的函数一直是一个提高 generalization 的方式,这也就是为什么 L2 正则化会有作用的原因。

为什么对 raw input 做 weighted sum 会有效?

通过 data-augmentation 来让 NN 在 “空白区域” 学到一个简单的线性插值函数,大大降低了无数据覆盖空间的复杂度

以前都是假设 feature space 是“光滑”的,现在 mixup 的意思,对 raw data 做 interpolation 就相当于在数据分布的上做 interpolation。(这点还是对数据分布采样来理解)

transet 可以看为高维空间中分布的一堆散点,通过 mixup 造出了一大批位于 trainset 散点之间的新点出来,扩充数据的同时,使散点更密集,这样 model 在拟合 trainset 时不容易过拟合,使 model 更准确.

但这个为什么会 work,还需要更多的理解。这也就是为什么 mixup 目前对 分类比较好,但对 检测 和 分割 还不明显,甚至会掉点。

目前,有两种实现

  1. gluon-cv/gluoncv/data/mixup/detection.py
  2. ResidualAttentionNetwork-pytorch/Residual-Attention-Network/train_mixup.py

事实上,第一种实现是做 Detection 的,而第二种实现是做 Classification 的。前者的实现是在 Dataset 阶段操作,后者的实现是在 DataLoader 返回 batch 后再 batch 内操作。

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment