
@noachr

noachr/blog.md

Created June 16, 2019 06:00

A fastai/PyTorch implementation of MixMatch

In this post, I will discuss and implement "MixMatch: A Holistic Approach to Semi-Supervised Learning" by Berthelot, Carlini, Goodfellow, Oliver, Papernot and Raffel [1]. Released in May 2019, MixMatch is a semi-supervised learning algorithm that significantly outperforms previous approaches. This post grew out of discussions within Dr. Ehsan Kamalinejad's machine learning research group at Cal State East Bay.

How much of an improvement is MixMatch? When trained on CIFAR10 with 250 labeled images, MixMatch outperforms the next best technique (Virtual Adversarial Training) by almost 25 percentage points in error rate (11.08% vs. 36.03%; for comparison, the fully supervised case on all 50k images has an error rate of 4.13%). These are far from incremental results, and the technique shows the potential to dramatically improve the state of semi-supervised learning.

Semi-supervised learning is largely a battle against overfitting: when the labeled set is small, it doesn't take a very large neural network to memorize the entire training set. The general idea behind nearly all semi-supervised approaches is to use unlabeled data as a regularizer when training on the labeled data. For a great overview of various semi-supervised learning methods, see this blog post by Sebastian Ruder. Different techniques employ different forms of regularization, and the MixMatch paper divides these into three groups: entropy minimization, consistency regularization, and generic regularization. As all three forms of regularization have proved effective, the MixMatch algorithm contains features from each.

MixMatch combines and improves upon several techniques that have appeared in recent years, including Mean Teacher [2], Virtual Adversarial Training [3], and Mixup [4]. At a high level, MixMatch labels the unlabeled data using the model's own predictions and then applies heavy regularization in several forms. First, each unlabeled image is augmented several times, and the model's predictions on the augmented copies are averaged to produce its label. These predictions are then 'sharpened' to reduce their entropy. Finally, Mixup is performed on the labeled and unlabeled sets.

I am aiming this post at those familiar with PyTorch, but not necessarily fastai. For a Jupyter notebook version of this post containing the full code needed to reproduce all the results, see this repository.

fastai

Before diving into the paper, I'll briefly talk about fastai. Fastai is a library, built on PyTorch, that makes writing machine learning applications much simpler. The fastai team also offers a terrific online course covering both fastai and deep learning in general. Compared to pure PyTorch, fastai dramatically reduces the amount of boilerplate code required to produce state-of-the-art neural networks. Here we'll be using fastai's data pipeline and training loop.

https://gist.github.com/55cbce3963ed31e1cc978d47b28d2470

Components

Let's first describe the individual pieces needed to assemble MixMatch, and then put them together at the end to form the complete algorithm. Following the paper, we'll use CIFAR10 and take 500 randomly selected images as the labeled training set. The standard 10,000-image test set is used for all accuracy measurements.

Data Augmentation

Data augmentation is a widely used consistency regularization technique, with its biggest success (so far) found in computer vision. The idea is to alter the input data while preserving its semantic label. For images, common augmentations include rotation, cropping, zooming, and brightening: all transformations that do not change the underlying content of the image. MixMatch takes this a step further by augmenting each unlabeled image multiple times to produce multiple new images. The model's predictions on these images are then averaged to produce a target for the unlabeled data, which makes the target more robust than a prediction on a single augmented image. The authors found that just two augments were sufficient to see this benefit.
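To make the averaging step concrete, here is a minimal plain-PyTorch sketch rather than the post's fastai code. It assumes the model returns unnormalized logits and that the K augmented views of an unlabeled batch arrive as a list of equally shaped tensors; the function name guess_labels is hypothetical.

```python
from typing import Sequence

import torch
import torch.nn as nn


def guess_labels(model: nn.Module, unlabeled_augs: Sequence[torch.Tensor]) -> torch.Tensor:
    """Average the model's class probabilities over K augmented views of one unlabeled batch.

    unlabeled_augs holds K tensors of shape (B, ...); entry k is the k-th augmentation
    of the same B images (a layout assumed for this sketch). Returns (B, num_classes).
    """
    # Guessed labels are used as fixed targets, so no gradient should flow through them.
    with torch.no_grad():
        probs = [torch.softmax(model(x), dim=1) for x in unlabeled_augs]
    # Stack to (K, B, num_classes) and average over the K augmentations.
    return torch.stack(probs, dim=0).mean(dim=0)


# Usage with a toy linear "model" and two noisy views of the same batch.
torch.manual_seed(0)
toy_model = nn.Linear(4, 3)
batch = torch.randn(8, 4)
views = [batch + 0.1 * torch.randn_like(batch) for _ in range(2)]
q = guess_labels(toy_model, views)
print(q.shape)  # torch.Size([8, 3])
```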

Fastai has an efficient transformation system, which we'll use on the data. However, since it is designed to produce only one augmentation per image and we need several, we'll start by modifying the default LabelList to emit multiple augments.
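Outside of fastai, the same idea can be sketched as a thin wrapper around any map-style PyTorch Dataset. This is not the modified LabelList from the post. The class name MultiAugmentDataset and the assumption that the base dataset yields (image, label) pairs are illustrative only.

```python
from typing import Callable

import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset


class MultiAugmentDataset(Dataset):
    """Wraps a dataset and returns K independently augmented copies of each image.

    `base` is assumed to yield (image, label) pairs; `augment` is any random transform.
    """

    def __init__(self, base: Dataset, augment: Callable, k: int = 2):
        self.base, self.augment, self.k = base, augment, k

    def __len__(self) -> int:
        return len(self.base)

    def __getitem__(self, idx: int):
        img, label = self.base[idx]
        # Each call to `augment` draws fresh random parameters, so the K views differ.
        views = tuple(self.augment(img) for _ in range(self.k))
        return views, label


# Usage: a toy dataset of blank "images" and an additive-noise "augmentation".
torch.manual_seed(0)
base = TensorDataset(torch.zeros(10, 3, 4, 4), torch.arange(10))
ds = MultiAugmentDataset(base, lambda x: x + torch.randn_like(x), k=3)
loader = DataLoader(ds, batch_size=5)
batch_views, labels = next(iter(loader))
```

With PyTorch's default collate function, each batch from this loader is a list of K batched image tensors followed by the label batch, so a training loop can pass each view through the model separately.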

https://gist.github.com/e18d4f5c07993a9a6c85c3f96fd908fb

Fastai's data block API allows for flexibly loading, labeling, and collating nearly any form of data. However, it doesn't have a method to grab a subset of one folder and the entirety of another folder, which is what we need here. Thus, we'll subclass the ImageList class and add a custom method. We'll call fastai's get_transforms with no arguments to use the default image transforms: horizontal flips (mirroring about the vertical axis), rotation of up to 10 degrees, zooming, lighting changes, and warping. Fastai's transform system automatically randomizes the exact parameters of each transform when it is applied.
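As a framework-agnostic sketch of the selection itself, the helper below draws a reproducible random labeled subset by index. It assumes the unlabeled pool is simply the remaining training images with their labels ignored; the post's fastai ImageList subclass may organize the folders differently, and split_labeled_unlabeled is a hypothetical name.

```python
import random
from typing import List, Tuple


def split_labeled_unlabeled(n_total: int, n_labeled: int, seed: int = 0) -> Tuple[List[int], List[int]]:
    """Randomly choose which training indices keep their labels; the rest form the unlabeled pool."""
    rng = random.Random(seed)  # seeded so the same labeled subset is drawn on every run
    labeled = sorted(rng.sample(range(n_total), n_labeled))
    labeled_set = set(labeled)
    unlabeled = [i for i in range(n_total) if i not in labeled_set]
    return labeled, unlabeled


# Usage: 500 labeled images out of CIFAR10's 50,000 training images.
labeled_idx, unlabeled_idx = split_labeled_unlabeled(50_000, 500)
print(len(labeled_idx), len(unlabeled_idx))  # 500 49500
```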

https://gist.github.com/389fa3692925b2f94a97d12b087dd923

Mixup

$$\begin{aligned} \lambda &\sim \mathrm{Beta}(\alpha,\alpha)\\ \lambda' &= \max(\lambda,1-\lambda)\\ \mathrm{Mixup}(a,b) &= \lambda' a + (1-\lambda')\, b \end{aligned}$$

Mixup was first introduced by Zhang, Cisse, Dauphin, and Lopez-Paz [4] in 2018 and falls into the category of general, or traditional, regularization. Instead of passing single images to the model, Mixup performs a linear interpolation between two separate training images and passes the result to the model. The one-hot encoded labels of the images are interpolated as well, using the same coefficient $\lambda$ as the images. That coefficient is drawn randomly from a Beta distribution parameterized by $\alpha$, which typically needs to be tuned to the dataset. At small values of $\alpha$, the Beta distribution puts most of its weight in the tails, close to 0 or 1. As $\alpha$ increases, the distribution becomes uniform (at $\alpha=1$) and then increasingly peaked around 0.5. Thus, $\alpha$ can be seen as controlling the intensity of the mixup: small values result in only a small amount of mixup, while larger values bias towards maximum mixup (50/50). At the extremes, as $\alpha\rightarrow 0$ there is no mixup at all, and as $\alpha\rightarrow\infty$ the Beta distribution approaches a Dirac delta centered at 0.5. The authors recommend starting with a value of 0.75, which still puts most of the weight in the tails. The MixMatch paper makes one modification to the original method, which is to set $\lambda$ to $\max(\lambda,1-\lambda)$; this biases the mixup towards the original image.
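The mixup formulas translate almost line for line into code. This NumPy sketch mixes a single pair with one shared coefficient; a batched implementation would typically draw a separate $\lambda$ per example. The function name mixmatch_mixup is hypothetical.

```python
import numpy as np


def mixmatch_mixup(x1, p1, x2, p2, alpha=0.75, rng=None):
    """Mix two (input, label) pairs with one coefficient drawn from Beta(alpha, alpha).

    Uses MixMatch's lambda' = max(lambda, 1 - lambda), so the result stays closer to (x1, p1).
    """
    rng = np.random.default_rng() if rng is None else rng
    lam = rng.beta(alpha, alpha)
    lam = max(lam, 1.0 - lam)
    x = lam * x1 + (1.0 - lam) * x2
    p = lam * p1 + (1.0 - lam) * p2  # labels are mixed with the same coefficient as inputs
    return x, p, lam


# Usage: mix a white and a black 32x32 RGB image with one-hot labels for classes 3 and 7.
rng = np.random.default_rng(0)
img_a, img_b = np.ones((32, 32, 3)), np.zeros((32, 32, 3))
label_a, label_b = np.eye(10)[3], np.eye(10)[7]
mixed_img, mixed_label, lam = mixmatch_mixup(img_a, label_a, img_b, label_b, rng=rng)
```

Because $\lambda' \ge 0.5$, the mixed image and label always stay dominated by the first argument, which is what lets MixMatch mix a labeled batch with arbitrary partners while keeping each example "mostly" itself.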

https://gist.github.com/a5971d8ea7ba00e0c6d0a7157f3bdad4


https://gist.github.com/e09ddcfa2c9d0c5e2653cc85fe05bc7b

Sharpening

$$\mathrm{Sharpen}(p,T)_i := \dfrac{p_i^{1/T}}{\sum_{j=1}^{L} p_j^{1/T}}$$

The authors sharpen the model's predictions on the unlabeled data with the above equation, where $p$ is the predicted class distribution and $L$ is the number of classes, as a form of entropy minimization. If the temperature $T < 1$, the effect is to make the predictions more certain, and as $T$ drops towards zero the predictions approach a one-hot distribution (see figure below). This relatively simple step, which involves no learned parameters, turns out to be incredibly important to the algorithm. In an ablation study, the paper reports an accuracy reduction of over 16 percentage points when the sharpening step is removed (setting $T$ to $1$).
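A minimal PyTorch version of the sharpening equation, assuming predictions arrive as a (batch, classes) tensor of probabilities. The entropy helper is included only to show that sharpening with $T<1$ lowers the entropy of a prediction; both function names are illustrative.

```python
import torch


def sharpen(p: torch.Tensor, T: float = 0.5) -> torch.Tensor:
    """Sharpen(p, T)_i = p_i^(1/T) / sum_j p_j^(1/T), applied row-wise to a (B, L) batch."""
    p = p ** (1.0 / T)
    return p / p.sum(dim=1, keepdim=True)


def entropy(p: torch.Tensor) -> torch.Tensor:
    """Shannon entropy (in nats) of each row of a (B, L) probability tensor."""
    return -(p * p.log()).sum(dim=1)


p = torch.tensor([[0.6, 0.3, 0.1]])
print(sharpen(p, T=0.5))  # approx. [[0.7826, 0.1957, 0.0217]]
print(entropy(p) > entropy(sharpen(p, T=0.5)))
```

With $T=0.5$ each probability is squared before renormalizing, so the largest class gains mass at the expense of the others; $T=1$ leaves the distribution unchanged.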

The idea behind entropy minimization in semi-supervised learning is that the decision boundary of the classifier should not pass through high-density regions of the data space. If it did, the boundary would split data points that are very close together, and small perturbations would result in large changes in predictions. Since predictions near the decision boundary are more uncertain, entropy minimization seeks to make the model more confident in its predictions, thus moving the boundary away from the data. While other approaches [3] add an entropy term to the loss, MixMatch directly lowers the entropy of the unlabeled targets via the equation above.

As an example of this technique, let's try a classification problem that is simpler and easier to visualize than CIFAR: MNIST. We'll again take 500 random examples as the labeled training set and reserve the rest as the unlabeled set. The full images are used for training, but we'll also reduce each image to two dimensions using t-SNE for visualization. Following the same approach as MixMatch for the unlabeled data, we'll train in a semi-supervised manner, using the model itself to generate pseudo-labels. The model consists of just two convolutional layers and a linear head. No mixup or data augmentation is used, so we can isolate the effects of entropy minimization. The loss function is also largely the same as MixMatch's, using cross-entropy for the labeled data and mean squared error for the unlabeled data; see the loss section below for the rationale. The upper image was trained without sharpening, and in the lower image the pseudo-labels were sharpened with $T=0.5$. After ten epochs of training each, the unsharpened model has a test accuracy of 80.1% and the sharpened model has an accuracy of 90.7%. In the images below, colors correspond to predicted class, and marker size is inversely proportional to prediction confidence (smaller markers are more confident). As the marker sizes show, the unsharpened model has a lot of uncertainty, especially around the edges of the clusters, while the sharpened model is much more confident in its predictions.

No sharpening

Sharpening

The effect of sharpening on the semi-supervised training of MNIST. Images in MNIST were reduced to two dimensions using t-SNE. Colors correspond to predicted class, and marker size is inversely proportional to prediction confidence (smaller markers are more confident). The upper image was trained with $T=1$, and the lower image with $T=0.5$.

https://gist.github.com/e7eb7c67e59a3927c0556a47c153e709

https://gist.github.com/455843afd3bda4d87ad328aa21faa732


The MixMatch Algorithm

Now with all the pieces in place, the full algorithm can be implemented. Here are the steps for a single training iteration:

  1. Supply a batch of labeled data with its labels, and a batch of unlabeled data.
  2. Augment the labeled batch to produce a new training batch.
  3. Augment each image in the unlabeled batch $K$ times, to produce a total of $BatchSize \times K$ new unlabeled examples.
  4. For each original image in the unlabeled batch, pass its $K$ augmented versions to the model. Average the model's predictions across the augments to produce a single pseudo-label shared by all $K$ augmented images.
  5. Sharpen the pseudo-labels.
  6. The augmented labeled data and its labels form set $X$. The augmented unlabeled data and its (predicted) labels form set $U$.
  7. Concatenate sets $U$ and $X$ into set $W$. Shuffle $W$.
  8. Form set $X'$ by applying mixup to set $X$ and the first $|X|$ examples of $W$.
  9. Form set $U'$ by applying mixup to set $U$ and the remaining examples of $W$ (those not used in step 8).

Sets $X'$ (labeled mixup) and $U'$ (unlabeled mixup) are then passed to the model, and the loss is computed using the corresponding mixed-up labels.
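Steps 4 through 9 can be sketched in plain PyTorch as follows. This is not the fastai callback used later in the post: it assumes steps 1 to 3 have already produced an augmented labeled batch and a list of K augmented unlabeled batches, that the model outputs logits, and that labels are integer class indices. All helper names are hypothetical, and this version draws a separate mixup coefficient for each example.

```python
from typing import Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F


def sharpen(p: torch.Tensor, T: float) -> torch.Tensor:
    p = p ** (1.0 / T)
    return p / p.sum(dim=1, keepdim=True)


def mixup(x1, p1, x2, p2, alpha: float):
    # One lambda per example, biased toward the first argument: lambda' = max(lambda, 1 - lambda).
    lam = torch.distributions.Beta(alpha, alpha).sample((x1.size(0),))
    lam = torch.max(lam, 1 - lam)
    lam_x = lam.view(-1, *([1] * (x1.dim() - 1)))  # broadcast over input dimensions
    lam_p = lam.view(-1, 1)                          # broadcast over classes
    return lam_x * x1 + (1 - lam_x) * x2, lam_p * p1 + (1 - lam_p) * p2


def mixmatch_batch(model: nn.Module, x_aug: torch.Tensor, y: torch.Tensor,
                   u_augs: Sequence[torch.Tensor], num_classes: int,
                   T: float = 0.5, alpha: float = 0.75):
    # Steps 4-5: average predictions over the K augmentations, then sharpen.
    with torch.no_grad():
        q = torch.stack([torch.softmax(model(u), dim=1) for u in u_augs]).mean(dim=0)
        q = sharpen(q, T)
    # Step 6: X = labeled inputs with one-hot targets; U = every augmented view with its guess.
    p = F.one_hot(y, num_classes).float()
    u_all = torch.cat(list(u_augs), dim=0)
    q_all = q.repeat(len(u_augs), 1)  # matches the [view 1; view 2; ...] order of u_all
    # Step 7: W = shuffled concatenation of X and U.
    w_x, w_p = torch.cat([x_aug, u_all]), torch.cat([p, q_all])
    perm = torch.randperm(w_x.size(0))
    w_x, w_p = w_x[perm], w_p[perm]
    # Step 8: mix X with the first |X| entries of W; step 9: mix U with the rest.
    n = x_aug.size(0)
    x_mix, p_mix = mixup(x_aug, p, w_x[:n], w_p[:n], alpha)
    u_mix, q_mix = mixup(u_all, q_all, w_x[n:], w_p[n:], alpha)
    return x_mix, p_mix, u_mix, q_mix


# Usage with a toy model: 8 labeled examples and K=2 views of 8 unlabeled examples.
torch.manual_seed(0)
toy = nn.Linear(6, 4)
x_l, y_l = torch.randn(8, 6), torch.randint(0, 4, (8,))
views = [torch.randn(8, 6) for _ in range(2)]
x_mix, p_mix, u_mix, q_mix = mixmatch_batch(toy, x_l, y_l, views, num_classes=4)
```

The mixed batches (x_mix, p_mix) and (u_mix, q_mix) are then fed to the model, and the loss compares its outputs to the mixed targets.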

The Model

We will use a WideResNet with 28 layers and a widening factor of 2 (WRN-28-2) to match the paper. I use fastai's included WRN implementation.
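For reference, the name WRN-28-2 encodes the depth (28) and the widening factor $k$ (2). Under the standard CIFAR WideResNet layout (three groups of residual blocks; this is an assumption about how fastai's implementation is configured), the depth satisfies depth $= 6N + 4$, where $N$ is the number of blocks per group, and the groups have $16k$, $32k$ and $64k$ channels after a 16-channel stem. The hypothetical helper below just does that arithmetic.

```python
from typing import List, Tuple


def wrn_config(depth: int, k: int) -> Tuple[int, List[int]]:
    """Blocks per group and channel widths (stem, group 1, group 2, group 3) of a CIFAR WideResNet."""
    if (depth - 4) % 6 != 0:
        raise ValueError("WideResNet depth must satisfy depth = 6N + 4")
    n_blocks = (depth - 4) // 6
    widths = [16] + [16 * k * 2**g for g in range(3)]
    return n_blocks, widths


print(wrn_config(28, 2))  # (4, [16, 32, 64, 128])
```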

https://gist.github.com/7276df26f57406307e59cedad290013d

Loss

With data and model in hand, we'll now implement the final piece required for training: the loss function. The loss is the sum of two terms, the labeled and unlabeled losses. The labeled loss uses standard cross-entropy, while the unlabeled loss is the $l_2$ (mean squared error) loss. This is because the $l_2$ loss is much less sensitive to very incorrect predictions. Cross-entropy is unbounded: as the model's predicted probability of the correct class goes to zero, cross-entropy goes to infinity. With the $l_2$ loss, however, since we are working with probabilities, the worst case for each class is that the model predicts 0 when the target is 1 or vice versa, which results in a loss of 1. Since the unlabeled targets come from the model itself, the algorithm shouldn't penalize incorrect predictions too harshly. The parameter $\lambda$ (l in the code, since lambda is a reserved word in Python) controls the balance between the two terms.
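The loss described above can be sketched as follows, assuming the model outputs logits and the targets are the (mixed-up) probability vectors produced by the MixMatch step. The name mixmatch_loss is hypothetical. Note that F.mse_loss averages over every element, so the squared error is divided by the batch size times the number of classes.

```python
import torch
import torch.nn.functional as F


def mixmatch_loss(logits_x: torch.Tensor, targets_x: torch.Tensor,
                  logits_u: torch.Tensor, targets_u: torch.Tensor,
                  lambda_u: float) -> torch.Tensor:
    """Soft-target cross-entropy on labeled data plus lambda_u times MSE on unlabeled data."""
    # Cross-entropy against soft targets: -sum_c p_c * log softmax(z)_c, averaged over the batch.
    loss_x = -(targets_x * F.log_softmax(logits_x, dim=1)).sum(dim=1).mean()
    # Squared error between predicted probabilities and guessed labels; each element is at most 1.
    loss_u = F.mse_loss(torch.softmax(logits_u, dim=1), targets_u)
    return loss_x + lambda_u * loss_u
```

With one-hot labeled targets, the first term reduces to ordinary cross-entropy; the soft-target form is needed because mixup produces fractional labels.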

We'll make one slight departure from the paper by linearly ramping up the weight of the unlabeled loss over the first 2000 iterations (roughly 10 epochs). Before adding this ramp-up, I had difficulty training the model and found it would very quickly collapse to a single prediction value. Since the predicted labels at the start of training are essentially random, it makes sense to delay applying the unlabeled loss. By the time the weight of the unlabeled loss becomes significant, the model should be making reasonably good predictions.
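The ramp-up is a one-line schedule. In this sketch the 2000-step length comes from the text, while the function name and the default weight of 75 (the $\lambda$ value in the hyperparameter table) are illustrative choices.

```python
def unlabeled_weight(step: int, lambda_u: float = 75.0, rampup_steps: int = 2000) -> float:
    """Linearly ramp the unlabeled-loss weight from 0 to lambda_u, then hold it constant."""
    return lambda_u * min(step / rampup_steps, 1.0)


print(unlabeled_weight(0), unlabeled_weight(1000), unlabeled_weight(5000))  # 0.0 37.5 75.0
```

The returned value would be passed as the unlabeled weight to the loss at each iteration.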

https://gist.github.com/393d37037f838d23aab478e5d73cbce3

Training

Before training, let's review the hyperparameters that have been introduced.

| Hyperparameter | Description | Value |
|---|---|---|
| $K$ | Number of augments | 2 |
| $T$ | Sharpening temperature | 0.5 |
| $\alpha$ | Beta distribution parameter | 0.75 |
| $\lambda$ | Unlabeled loss weight | 75 |

The authors of the paper claim that $T$ and $K$ should be relatively constant across most datasets, while $\alpha$ and $\lambda$ need to be tuned per set. We'll use the same hyperparameters as the paper's official implementation.

One implementation detail: the paper mentions that, instead of learning rate annealing, it updates a second model with the exponential moving average (EMA) of the training model's parameters. This is yet another form of regularization, but it is not essential to the algorithm. For those interested, there is code for training with an EMA model in the repository. However, I didn't find a significant benefit over learning rate scheduling, so in the name of simplicity we'll forgo EMA and use fastai's implementation of the one cycle policy to schedule the learning rate and momentum.
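For readers who want to try the EMA variant, here is a minimal sketch of a parameter EMA in plain PyTorch. It is not the code from the repository; the class name EMAModel and the default decay of 0.999 are assumptions. Buffers such as BatchNorm running statistics are copied rather than averaged, which is one common choice.

```python
import copy

import torch
import torch.nn as nn


class EMAModel:
    """Keeps an exponential moving average of another model's parameters."""

    def __init__(self, model: nn.Module, decay: float = 0.999):
        self.decay = decay
        self.ema = copy.deepcopy(model)  # separate copy used only for evaluation
        for p in self.ema.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model: nn.Module) -> None:
        for ema_p, p in zip(self.ema.parameters(), model.parameters()):
            # ema <- decay * ema + (1 - decay) * current
            ema_p.mul_(self.decay).add_(p, alpha=1 - self.decay)
        # Buffers (e.g. BatchNorm running mean/var) are copied directly.
        for ema_b, b in zip(self.ema.buffers(), model.buffers()):
            ema_b.copy_(b)
```

In use, update(model) would be called after every optimizer step, and the averaged copy in .ema would be used for evaluation.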

We'll use fastai's callback system to write a method which handles most of the MixMatch steps. This method takes in batches from the labeled and unlabeled sets, gets the predicted labels, and then performs mixup.

https://gist.github.com/69bc7917adb8fa27be66050651e186b9

A fastai Learner object contains the dataloaders and the model, and is responsible for executing the training loop. It also has a lot of utility functions, such as learning rate finding and prediction interpretation.

https://gist.github.com/48578209d542bfccc41b6cbc329ac47b

Results

For reference, I ran these tests on a Google Compute Engine virtual machine with 16 CPUs and a single P100 GPU. The first step is to establish some baselines so that MixMatch's performance can be compared. First, I'll try the fully supervised case with all 50k training images.

https://gist.github.com/045a26596923be2649c632e10c1b0a34


| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 1.191859 | 1.232437 | 0.570300 | 01:53 |
| 1 | 0.923196 | 0.926058 | 0.676400 | 01:54 |


Next I will try training on just the 500 labeled images, with no unsupervised component.

https://gist.github.com/c6fe17c9f6f7eb19b4e59c9081f9a9d7

Finally I will train with MixMatch, using the learner defined in the previous section. Note that I now use learn.fit instead of learn.fit_one_cycle since MixMatch uses EMA instead of a learning rate schedule.

https://gist.github.com/7938674c1507e698e05d96c9b7698f74

Conclusion

While MixMatch clearly boasts impressive performance, the downside is the additional time cost in training. Compared to the fully supervised case, training MixMatch takes me approximately 4x longer. Some of this may be due to inefficiencies in my implementation, but generating multiple augmentations and then obtaining model predictions for labels has a significant cost, especially on a single GPU.

While augmentation and sharpening are hugely beneficial, the paper's ablation study shows that the single most important component, in terms of error, is MixUp. It is also the most mysterious component in terms of why it works so well: why should enforcing linearity in predictions between images help the model? Certainly it reduces memorization of the training data, but so does data augmentation, and not to nearly the same effect in this case. Even the original MixUp paper offers only informal arguments for its efficacy; from that paper:

"We argue that this linear behaviour reduces the amount of undesirable oscillations when predicting outside the training examples. Also, linearity is a good inductive bias from the perspective of Occam’s razor, since it is one of the simplest possible behaviors" [4]

Other researchers have expanded upon the idea, for example by mixing up intermediate states instead of the inputs [7], or by using a neural network instead of the Beta distribution to generate the mixup coefficient [6]. However, I have been unable to find a solid theoretical justification; this is yet another technique that falls into the 'it just works' category. It would certainly be difficult to draw a biological analogy: humans hardly learn a concept by blending it with an unrelated one.

That said, MixMatch is hugely promising and it will be interesting to see it applied to other domains beyond vision.

References

[1]: Berthelot, David, Nicholas Carlini, Ian Goodfellow, Nicolas Papernot, Avital Oliver, and Colin Raffel. “MixMatch: A Holistic Approach to Semi-Supervised Learning.” ArXiv:1905.02249 [Cs, Stat], May 6, 2019. http://arxiv.org/abs/1905.02249.

[2]: Tarvainen, Antti, and Harri Valpola. “Mean Teachers Are Better Role Models: Weight-Averaged Consistency Targets Improve Semi-Supervised Deep Learning Results.” ArXiv:1703.01780 [Cs, Stat], March 6, 2017. http://arxiv.org/abs/1703.01780.

[3]: Miyato, Takeru, Shin-ichi Maeda, Masanori Koyama, and Shin Ishii. “Virtual Adversarial Training: A Regularization Method for Supervised and Semi-Supervised Learning.” ArXiv:1704.03976 [Cs, Stat], April 12, 2017. http://arxiv.org/abs/1704.03976.

[4]: Zhang, Hongyi, Moustapha Cisse, Yann N. Dauphin, and David Lopez-Paz. “Mixup: Beyond Empirical Risk Minimization.” ArXiv:1710.09412 [Cs, Stat], October 25, 2017. http://arxiv.org/abs/1710.09412.

[5]: Polyak, Boris, and Anatoli Juditsky. “Acceleration of Stochastic Approximation by Averaging.” SIAM Journal on Control and Optimization 30 (July 1, 1992): 838–55. https://doi.org/10.1137/0330046.

[6]: Guo, Hongyu, Yongyi Mao, and Richong Zhang. “MixUp as Locally Linear Out-Of-Manifold Regularization,” n.d., 9.

[7]: Verma, Vikas, Alex Lamb, Christopher Beckham, Amir Najafi, Ioannis Mitliagkas, Aaron Courville, David Lopez-Paz, and Yoshua Bengio. “Manifold Mixup: Better Representations by Interpolating Hidden States,” June 13, 2018. https://arxiv.org/abs/1806.05236v7.
