Troubleshooting Convolutional Neural Networks

Intro

This is a list of hacks gathered primarily from prior experience as well as online sources (most notably Stanford's CS231n course notes) on how to troubleshoot the performance of a convolutional neural network. We will focus mainly on supervised learning using deep neural networks. While this guide assumes the user is coding in Python 3.6 using TensorFlow (TF), it can still be helpful as a language-agnostic guide.

Suppose we are given a convolutional neural network to train and evaluate, and assume the evaluation results are worse than expected. The following are steps to troubleshoot and potentially improve performance. The first section covers must-dos and general good practices before you start troubleshooting. Every subsequent section header corresponds to a problem, and the section is devoted to solving it. The sections are ordered so that more common issues come first, and within each section the most easily implementable fixes come first.

Before Troubleshooting

The following are best practices to follow when coding your deep learning algorithm. Great sources on the topic are the lecture notes here and here as well as this review paper by Bengio.

  1. Use appropriate logging and meaningful variable names. In TF you can keep track of different variables by name and visualize the graph in TensorBoard. Most importantly, every few training steps make sure you are logging relevant values, such as step_number, accuracy, loss, learning_rate and, if applicable, more specific metrics (such as mean_intersection_over_union, aka mean_iou, for segmentation tasks). Plot a curve of loss versus step number.
  2. Make sure your network is wired correctly. Use TensorBoard or other debugging techniques to make sure every operation in the graph receives the appropriate inputs and produces the expected outputs. Also make sure you are pre-processing your data correctly and pairing each input with the right label before feeding them into the network.
  3. Implement data augmentation techniques. This is not always applicable; however, if you are dealing with images you will almost always see large performance improvements from simple data augmentation techniques (such as mirroring, rotating, randomly cropping and rescaling, adding noise, elastic deformation, etc.). TF has built-in functions for most of these operations.
  4. Use proper weight initialization and regularization for all layers. Avoid initializing weights to the same value, or worse yet to 0. Doing so introduces symmetry (every neuron in a layer computes the same output and receives the same update, so the neurons never differentiate) and potentially vanishing gradients, and in most cases leads to terrible results. In general, if you are having trouble with weight initialization, consider adding Batch Normalization layers to your network.
  5. Make sure the regularization terms are not overwhelming the other terms in the loss function. Turn off regularization, find out the order of magnitude of the loss, then adjust the regularization weights accordingly. As you increase the regularization strength, check that the loss increases as well.
  6. Try overfitting a small dataset. Turn off regularization/dropout/data augmentation/batch normalization, take a very small portion of your training set, and let the network train for a few epochs. Make sure you can achieve zero loss; if not, something is most likely wrong. In some instances driving the loss to zero is particularly challenging. For example, in semantic image segmentation, if your loss is the cross entropy between the softmax-ed logits and the ground truth labels for each pixel, it may be really difficult to drive it down to 0. Instead, you should aim for close to 100% accuracy (computed by taking the argmax of the softmax-ed logits and comparing it with the ground truth labels; see tf.metrics.accuracy for more detail).
  7. While overfitting the small dataset mentioned above, find a reasonable learning rate. Directly from Bengio's paper: "The optimal learning rate is usually close to (by a factor of 2) the largest learning rate that does not cause divergence of the training criterion, an observation that can guide heuristics for setting the learning rate (Bengio, 2011), e.g., start with a large learning rate and if the training criterion diverges, try again with 3 times smaller learning rate, etc., until no divergence is observed."
  8. Perform gradient checks. Gradient checks are particularly relevant if you are using custom operations in your graph, i.e. ops that are not built into TF. There are several "hacks" for implementing an accurate gradient check.
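To illustrate item 3 together with the pairing concern in item 2, here is a minimal plain-Python sketch (no TF, so it runs anywhere) that mirrors, randomly crops and adds noise to an image stored as a list of rows. It assumes a segmentation-style setup in which the label is a mask of the same size, so the geometric transforms are applied identically to image and mask while noise is added to the image only. The helper names and the 0.5 flip probability are illustrative assumptions, not TF APIs; in TF you would typically reach for built-ins such as tf.image.flip_left_right instead.

```python
import random

def mirror(grid):
    """Flip a 2-D grid (list of rows) left to right."""
    return [list(reversed(row)) for row in grid]

def crop(grid, top, left, height, width):
    """Return the height x width window whose top-left corner is (top, left)."""
    return [row[left:left + width] for row in grid[top:top + height]]

def augment_pair(image, mask, crop_h, crop_w, noise_sigma=0.01, seed=None):
    """Hypothetical helper: same random flip and crop for image and mask.

    Gaussian noise is added to the image only; the mask keeps its exact
    values because they must remain valid class ids.
    """
    rng = random.Random(seed)
    if rng.random() < 0.5:  # mirror roughly half of the time
        image, mask = mirror(image), mirror(mask)
    h, w = len(image), len(image[0])
    top = rng.randint(0, h - crop_h)   # randint is inclusive on both ends
    left = rng.randint(0, w - crop_w)
    image = crop(image, top, left, crop_h, crop_w)
    mask = crop(mask, top, left, crop_h, crop_w)
    image = [[px + rng.gauss(0.0, noise_sigma) for px in row] for row in image]
    return image, mask
```

Drawing the flip and crop parameters once and reusing them for both arrays is what keeps data and labels paired; augmenting them independently would silently corrupt the labels.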
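Item 5 boils down to a little arithmetic. The sketch below assumes an L2 (weight decay) penalty of the form 0.5 * lambda * sum(w^2), a data loss measured with regularization turned off, and a hypothetical target in which the penalty is about 10% of the data loss; that fraction is an illustrative starting point, not a rule from the sources above.

```python
def squared_norm(weights):
    """Sum of squared weights over all layers (each layer is a flat list)."""
    return sum(w * w for layer in weights for w in layer)

def l2_penalty(weights, reg_strength):
    """L2 regularization term: 0.5 * lambda * ||W||^2."""
    return 0.5 * reg_strength * squared_norm(weights)

def suggest_reg_strength(data_loss, weights, target_fraction=0.1):
    """Choose lambda so the penalty is roughly target_fraction of the data loss.

    data_loss should come from a run with regularization turned off.
    target_fraction is a hypothetical knob, not a recommended constant.
    """
    return 2.0 * target_fraction * data_loss / squared_norm(weights)

def total_losses(data_loss, weights, strengths):
    """Total loss for each regularization strength, holding the weights fixed."""
    return [data_loss + l2_penalty(weights, lam) for lam in strengths]
```

With the weights held fixed, total_losses is strictly increasing in lambda. In a real run the trend is noisier, but if the reported loss does not change at all as you raise lambda, the penalty is probably not part of the tensor being optimized.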
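Item 6 can be automated as a sanity test. The sketch below uses an unregularized logistic-regression model in plain Python as a hypothetical stand-in for your network, trains it by full-batch gradient descent on a handful of linearly separable points, and reports the final cross-entropy loss and the accuracy obtained by thresholding the predicted probability. With a real CNN you would call your own training step instead; the step count, learning rate and thresholds used in the check are illustrative.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def overfit_small_subset(xs, ys, steps=500, lr=1.0):
    """Fit an unregularized logistic-regression stand-in on a tiny subset.

    No dropout, augmentation or regularization, as item 6 suggests.
    Returns (final_loss, accuracy).
    """
    dim, n = len(xs[0]), len(xs)
    w, b = [0.0] * dim, 0.0
    for _ in range(steps):
        grad_w, grad_b = [0.0] * dim, 0.0
        for x, y in zip(xs, ys):
            # Gradient of mean binary cross entropy w.r.t. the logit is (p - y).
            err = sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b) - y
            for i in range(dim):
                grad_w[i] += err * x[i] / n
            grad_b += err / n
        w = [wi - lr * gi for wi, gi in zip(w, grad_w)]
        b -= lr * grad_b
    probs = [sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b) for x in xs]
    loss = -sum(y * math.log(p) + (1 - y) * math.log(1 - p)
                for p, y in zip(probs, ys)) / n
    accuracy = sum((p > 0.5) == bool(y) for p, y in zip(probs, ys)) / n
    return loss, accuracy
```

If even a tiny subset cannot be fit, suspect the wiring, the labels or the loss before touching hyperparameters.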
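The divide-by-3 heuristic quoted in item 7 fits in a small loop. In this sketch the "training run" is a hypothetical stand-in: plain gradient descent on the one-dimensional quadratic 0.5 * c * x^2, which diverges whenever lr > 2 / c. A run is treated as diverged if any loss is non-finite or the final loss exceeds the initial one. Both the toy objective and that divergence criterion are assumptions; with a real network you would run a few hundred steps on the small subset from item 6 instead.

```python
import math

def quadratic_run(lr, curvature=10.0, x0=1.0, steps=50):
    """Hypothetical stand-in for a short training run.

    Gradient descent on f(x) = 0.5 * curvature * x^2 (gradient: curvature * x).
    Returns the loss before each step plus the loss after the last step.
    """
    x = x0
    losses = []
    for _ in range(steps):
        losses.append(0.5 * curvature * x * x)
        x -= lr * curvature * x
    losses.append(0.5 * curvature * x * x)
    return losses

def diverged(losses):
    """Assumed criterion: any non-finite loss, or ending above the starting loss."""
    return not all(math.isfinite(v) for v in losses) or losses[-1] > losses[0]

def find_learning_rate(run, initial_lr, factor=3.0, max_tries=30):
    """Start with a large rate and divide by `factor` until a run does not diverge."""
    lr = initial_lr
    for _ in range(max_tries):
        if not diverged(run(lr)):
            return lr
        lr /= factor
    raise RuntimeError("every learning rate tried diverged")
```

The returned value is only the largest non-diverging rate on a coarse grid of powers of 3. Per the quote, the optimum is usually within about a factor of 2 of the true largest non-diverging rate, so a finer search around the returned value is a reasonable next step.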
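For item 8, the sketch below compares an analytic gradient against the centered finite difference (f(x + h) - f(x - h)) / (2h) and reports the worst relative error |a - n| / max(|a|, |n|). These are two commonly recommended tricks: the centered formula is more accurate than the one-sided one, and the relative error does not depend on the scale of the gradient. The example function, the deliberately buggy gradient and the pass threshold of about 1e-6 are illustrative assumptions; for a custom op you would plug in its forward computation as f and its hand-written gradient as grad_f.

```python
def numerical_gradient(f, x, h=1e-5):
    """Centered finite-difference estimate of the gradient of f at x."""
    grad = []
    for i in range(len(x)):
        x_plus, x_minus = list(x), list(x)
        x_plus[i] += h
        x_minus[i] -= h
        grad.append((f(x_plus) - f(x_minus)) / (2.0 * h))
    return grad

def relative_error(a, b, eps=1e-12):
    """|a - b| / max(|a|, |b|), guarded against division by zero."""
    return abs(a - b) / max(abs(a), abs(b), eps)

def gradient_check(f, grad_f, x, h=1e-5):
    """Worst-case relative error between analytic and numerical gradients."""
    analytic = grad_f(x)
    numeric = numerical_gradient(f, x, h)
    return max(relative_error(a, n) for a, n in zip(analytic, numeric))

# Example with a known gradient: f(x) = x0^3 + x1^3 + x2^3 + x0 * x1
def f(x):
    return x[0] ** 3 + x[1] ** 3 + x[2] ** 3 + x[0] * x[1]

def grad_f(x):
    return [3 * x[0] ** 2 + x[1], 3 * x[1] ** 2 + x[0], 3 * x[2] ** 2]

def buggy_grad_f(x):
    # Deliberate bug: d(x^3)/dx written as 2x^2 instead of 3x^2.
    return [2 * x[0] ** 2 + x[1], 2 * x[1] ** 2 + x[0], 2 * x[2] ** 2]
```

A correct gradient gives a tiny relative error, while the buggy one is off by tens of percent, which is the kind of gap a gradient check is meant to expose.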

Loss Value Not Improving

If you train for several epochs and your loss value is not improving or if it is getting worse:

  1. Make sure you are using an appropriate loss and optimizing the correct tensor. A list of common loss functions is available here.
  2. Use a decent optimizer. A list of common optimizers is available here.
  3. Make sure your variables are training. To check this, you may want to look at TensorBoard's histograms, or write a script that computes the norm (L1 or L∞) of each tensor at a few different training steps and prints the names of the tensors whose norms stay constant. If your variables are not training as expected, please refer to the Variable Not Training section.
  4. Adjust the initial learning rate and implement an appropriate learning rate schedule. This is perhaps the most impactful "fix". If your loss is getting worse, your learning rate is likely too large. On the other hand, if your loss is nearly constant, your learning rate is likely too small. Either way, once you identify a valid initial learning rate you should implement a learning rate decay schedule. Some optimizers, like Adam, adapt per-parameter step sizes internally; however, that adaptation is typically not aggressive enough, and it may be a good idea to implement your own learning rate schedule on top of it.
  5. Make sure you are not overfitting. There are several ways to overfit and several ways to avoid it. Make a plot of the loss versus the number of training steps. Overfitting is explained here. Please refer to the Overfitting section for possible ways to prevent it.
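The norm-tracking script described in item 3 might look like the following sketch. It assumes you have already dumped each variable's values to flat Python lists at a few training steps (in TF1, for example, by evaluating tf.trainable_variables() in a session); the snapshot format and the function names are hypothetical. A constant norm is a strong hint rather than proof that a variable is frozen, since an update could in principle leave the norm unchanged.

```python
import math

def l1_norm(values):
    return sum(abs(v) for v in values)

def linf_norm(values):
    return max(abs(v) for v in values)

def frozen_variables(snapshots, rel_tol=1e-9, abs_tol=1e-12):
    """Names of variables whose L1 and L-infinity norms never change.

    snapshots: list of dicts mapping variable name -> flat list of values,
    each dict taken at a different training step (hypothetical format).
    """
    frozen = []
    for name in snapshots[0]:
        first_l1 = l1_norm(snapshots[0][name])
        first_linf = linf_norm(snapshots[0][name])
        unchanged = all(
            math.isclose(l1_norm(s[name]), first_l1, rel_tol=rel_tol, abs_tol=abs_tol)
            and math.isclose(linf_norm(s[name]), first_linf, rel_tol=rel_tol, abs_tol=abs_tol)
            for s in snapshots[1:]
        )
        if unchanged:
            frozen.append(name)
    return frozen
```

Printing the returned names after a few hundred steps gives the list of suspects to investigate with the checks in the Variable Not Training section.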
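For item 4, a common decay schedule is exponential decay. The sketch below follows the formula documented for TF1's tf.train.exponential_decay, lr * decay_rate ** (step / decay_steps), where staircase=True switches to integer division so the rate drops in discrete steps. The numbers in the example are illustrative, not recommended values; the decayed value would be passed to the optimizer (Adam included) as its learning rate at each step.

```python
def exponential_decay(initial_lr, step, decay_steps, decay_rate, staircase=False):
    """Decayed learning rate at a given training step."""
    if staircase:
        exponent = step // decay_steps  # piecewise-constant (step) decay
    else:
        exponent = step / decay_steps   # smooth decay
    return initial_lr * decay_rate ** exponent

# Example: halve a 1e-3 learning rate every 10,000 steps.
for step in (0, 5000, 10000, 25000):
    print(step, exponential_decay(1e-3, step, 10000, 0.5, staircase=True))
```

The staircase variant keeps the rate constant within each decay window, which makes the effect of each drop easy to spot on the loss-versus-step plot.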

Variable Not Training

Use TensorBoard's histograms, or write a script that computes the norm of each tensor at a few different training steps and prints the names of the tensors whose norms stay constant. If a variable is not training as expected:

  1. Make sure it is treated as a trainable variable by TF. Look into TF GraphKeys (in particular the TRAINABLE_VARIABLES collection) for more details.
  2. Make sure your gradient updates are not vanishing. If downstream variables (those closer to the output) are training fine but upstream variables (those closer to the input) are nearly constant, you are probably running into a vanishing gradient problem. Please refer to the Vanishing/Exploding Gradients section.
  3. Make sure your ReLUs are firing. If a large portion of your neurons are clamped to zero, revisit your weight initialization strategy, try a less aggressive learning rate schedule, and try decreasing weight decay regularization.
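To check item 3, you can measure the fraction of "dead" units: ReLUs whose output max(0, z) is zero for every example in a batch. The sketch below assumes you have fetched one layer's pre-activations for a batch as a list (examples) of lists (units); the function name and data layout are hypothetical. A unit that is silent on one batch may still fire on others, so it is worth tracking this number over several batches.

```python
def dead_relu_fraction(pre_activations):
    """Fraction of units whose ReLU output is zero for every example.

    pre_activations: list over examples, each a list over units.
    A unit is counted as dead when all of its pre-activations are <= 0.
    """
    num_units = len(pre_activations[0])
    dead = sum(
        1 for j in range(num_units)
        if all(example[j] <= 0.0 for example in pre_activations)
    )
    return dead / num_units
```

A fraction that climbs steadily during training is the clamping described in item 3, and points back at the initialization, learning rate schedule and weight decay.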

Vanishing/Exploding Gradients

  1. Consider using better weight initialization. This is especially relevant if the gradient updates are very small when training starts.
  2. Consider changing your activation functions. If you are using ReLUs, consider substituting them with leaky ReLUs or Maxout activations. You should completely avoid sigmoid activations in hidden layers and generally stay away from tanh.
  3. If you are using a recurrent neural network, consider using LSTM blocks. See the discussion here.
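Items 1 and 2 interact, because the right initialization scale depends on the activation. The following plain-Python sketch pushes a random batch through a stack of square fully connected layers (a simplification of a CNN, with no biases) and reports the root-mean-square activation at the top. It compares the He-style standard deviation sqrt(2 / ((1 + alpha^2) * fan_in)), where alpha is the leaky ReLU slope (alpha = 0 gives a plain ReLU), against a fixed small standard deviation of 0.01. Depth, width, batch size and seed are arbitrary illustrative choices. When activations collapse toward zero layer after layer like this, the backpropagated gradients typically shrink in the same way.

```python
import math
import random

def leaky_relu(z, alpha):
    """Leaky ReLU; alpha = 0 gives a plain ReLU."""
    return z if z > 0 else alpha * z

def he_std(fan_in, alpha=0.0):
    """He-style weight standard deviation for (leaky) ReLU layers."""
    return math.sqrt(2.0 / ((1.0 + alpha ** 2) * fan_in))

def forward_rms(depth, width, batch, weight_std, alpha=0.0, seed=0):
    """RMS activation after `depth` random dense layers of size width x width."""
    rng = random.Random(seed)
    acts = [[rng.gauss(0.0, 1.0) for _ in range(width)] for _ in range(batch)]
    for _ in range(depth):
        w = [[rng.gauss(0.0, weight_std) for _ in range(width)] for _ in range(width)]
        acts = [
            [leaky_relu(sum(a[i] * w[i][j] for i in range(width)), alpha)
             for j in range(width)]
            for a in acts
        ]
    return math.sqrt(sum(v * v for a in acts for v in a) / (batch * width))

# Compare He initialization with a fixed, too-small standard deviation.
he = forward_rms(depth=8, width=64, batch=16, weight_std=he_std(64))
tiny = forward_rms(depth=8, width=64, batch=16, weight_std=0.01)
print(he, tiny)
```

With He scaling the RMS stays on the order of 1, while the 0.01 initialization shrinks it by many orders of magnitude over just eight layers.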

Overfitting

Overfitting is a scenario in which your network "memorizes" the training data instead of learning patterns that generalize. It is typically indicated by a large gap between the training and validation accuracy curves (see the Train/Val accuracy section here).

  1. Implement data augmentation techniques. Refer to the Before Troubleshooting section.
  2. Implement dropout. Dropout consists of randomly ignoring some neurons at each training step. The contribution of those neurons to the forward pass is removed, and they are not updated during the backward pass. Please refer here for more.
  3. Increase regularization.
  4. Implement batch normalization. Please refer here for more info.
  5. Implement validation-based early stopping. Overfitting may occur because the network was trained for too many epochs. Early stopping helps eliminate this problem. Please refer here for more.
  6. If everything else fails, use a smaller network. This really should be your last resort, and in fact these course notes caution against doing so.
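A minimal sketch of dropout (item 2), using the "inverted dropout" convention: during training each activation is kept with probability keep_prob and the survivors are scaled by 1 / keep_prob, so the expected activation is unchanged and nothing needs rescaling at evaluation time. TF's tf.nn.dropout follows the same scaling convention. The flat-list interface is a hypothetical simplification; a real implementation also reuses the same mask in the backward pass, which is why the dropped units receive no update for that step.

```python
import random

def dropout(activations, keep_prob, training, rng=random):
    """Inverted dropout over a flat list of activations."""
    if not 0.0 < keep_prob <= 1.0:
        raise ValueError("keep_prob must be in (0, 1]")
    if not training:
        return list(activations)  # identity at evaluation time
    # Keep each unit with probability keep_prob and rescale the survivors.
    return [a / keep_prob if rng.random() < keep_prob else 0.0 for a in activations]
```

Passing a seeded random.Random instance as rng makes the mask reproducible, which helps when debugging.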
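Validation-based early stopping (item 5) is usually implemented with a patience counter: stop once the validation loss has not improved for a given number of consecutive evaluations, and keep the checkpoint from the best one. The class below is a hypothetical minimal sketch; checkpoint saving is left as a comment because it depends on your framework, and the patience value in the example is illustrative.

```python
class EarlyStopper:
    """Stop training after `patience` evaluations without improvement."""

    def __init__(self, patience):
        self.patience = patience
        self.best_loss = float("inf")
        self.best_epoch = None
        self.epochs_without_improvement = 0

    def step(self, epoch, val_loss):
        """Record one validation loss; return True if training should stop."""
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.epochs_without_improvement = 0
            # In a real training loop, save a checkpoint of the model here.
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience

# Example: the validation loss bottoms out at epoch 2, then rises.
stopper = EarlyStopper(patience=3)
for epoch, val_loss in enumerate([1.0, 0.8, 0.7, 0.72, 0.75, 0.9, 1.1]):
    if stopper.step(epoch, val_loss):
        print("stopping at epoch", epoch, "best epoch", stopper.best_epoch)
        break
```

After stopping, restore the checkpoint saved at best_epoch rather than keeping the final weights.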

More things you may try

  1. Consider using a weighted loss function. For instance, in semantic image segmentation the network is asked to classify each pixel in the input image. Some classes may occur very rarely compared to others; in this case, giving the rare classes larger weights should improve the mean_iou metric.
  2. Change your network architecture. Your network may be too deep or too shallow.
  3. Consider using an ensemble of models.
  4. Replace max/average-pooling layers with strided convolutions.
  5. Perform a thorough hyperparameter search.
  6. Change the random seed.
  7. If all else fails, acquire more data.
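One common way to build the weighted loss in item 1 is inverse class frequency: weight_c = N / (K * count_c) for N labelled pixels and K classes, so that every class contributes the same total weight. The sketch below applies those weights to a per-pixel cross entropy in plain Python and normalizes by the total weight. Both the weighting formula and the normalization are one choice among several, assumed here for illustration rather than prescribed above.

```python
import math
from collections import Counter

def inverse_frequency_weights(labels, num_classes):
    """weight_c = N / (K * count_c); classes absent from labels get weight 0."""
    counts = Counter(labels)
    n = len(labels)
    return [n / (num_classes * counts[c]) if counts[c] else 0.0
            for c in range(num_classes)]

def weighted_cross_entropy(probs, labels, class_weights):
    """Weighted mean of -log p(true class) over pixels.

    probs: list of per-pixel probability vectors (already softmax-ed).
    labels: list of per-pixel ground-truth class ids.
    """
    total = sum(class_weights[y] * -math.log(p[y]) for p, y in zip(probs, labels))
    norm = sum(class_weights[y] for y in labels)
    return total / norm
```

With 8 background pixels and 2 rare-class pixels the weights come out as 0.625 and 2.5, so a model that always predicts the background is penalized much more heavily than under the unweighted loss.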
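For item 4, the standard output-size formula floor((n + 2p - k) / s) + 1 (input size n, kernel size k, stride s, padding p) shows how to pick a strided convolution that downsamples like the pooling layer it replaces. The helper below is a plain-Python sketch assuming explicit zero padding; the 3x3, stride-2, padding-1 configuration in the example is one common choice, not the only one.

```python
def conv_output_size(n, kernel, stride, padding=0):
    """Spatial output size of a convolution or pooling window along one axis."""
    return (n + 2 * padding - kernel) // stride + 1

# A 2x2 / stride-2 pooling layer versus a 3x3 / stride-2 / padding-1 convolution.
for n in (32, 33):
    print(n, conv_output_size(n, kernel=2, stride=2),
          conv_output_size(n, kernel=3, stride=2, padding=1))
```

With these settings the two agree for even input sizes but differ by one for odd ones (33 gives 16 for the pooling layer and 17 for the convolution), so it is worth re-checking downstream shapes after the swap.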