NN debugging tips

Resources

Debugging training

These problems and solutions are taken from TFCheck, a paper describing a TensorFlow library for detecting common training errors. Below is a summary of the bugs the authors identify and the techniques they develop to catch them.

Problem 1: Untrained parameters

Sometimes one can forget to wire in a network module, so its parameters never end up being updated. We can catch this bug by (a code sketch follows the list):

  1. Storing a checkpoint of the parameters at the start of training
  2. Checking, after a few iterations, that the difference between the stored weights and the current weights is greater than some threshold
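
A minimal sketch of both steps, using PyTorch rather than the paper's TensorFlow library (the choice of framework and the `threshold` value are my own assumptions):

```python
import torch


def snapshot_params(model: torch.nn.Module) -> dict:
    """Step 1: store a copy of every named parameter at the start of training."""
    return {name: p.detach().clone() for name, p in model.named_parameters()}


def check_params_updated(model: torch.nn.Module, snapshot: dict, threshold: float = 1e-6):
    """Step 2: after a few iterations, flag parameters that have barely moved."""
    stale = []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue  # deliberately frozen parameters are fine
        delta = (p.detach() - snapshot[name]).abs().max().item()
        if delta < threshold:
            stale.append(name)
    if stale:
        raise RuntimeError(f"These parameters never changed: {stale}")
```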

Problem 2: Symmetric weights

When layer weights aren't initialised to different values, their gradients will be identical and hence the weights will remain equal to one another throughout training. We can catch this by computing the variance of each weight matrix after initialisation and checking that it is not close to zero.
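
A sketch of this symmetry check in PyTorch (the `eps` tolerance is an assumed value):

```python
import torch


def check_weight_symmetry(model: torch.nn.Module, eps: float = 1e-8):
    """Right after initialisation, each weight matrix should have non-trivial variance."""
    for name, p in model.named_parameters():
        if "weight" in name and p.numel() > 1:
            if p.detach().float().var().item() < eps:
                raise RuntimeError(f"{name} looks symmetric: variance is ~0 after initialisation")
```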

Problem 3: Parameter divergence

Weights can grow towards +∞/-∞ if the learning rate is too high or there is insufficient regularisation. These issues can be caught by checking the lower/upper quartile of the weights every iteration and ensuring it stays above/below a threshold.
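
One possible form of the quartile check in PyTorch (the [-10, 10] band is an arbitrary placeholder threshold):

```python
import torch


def check_param_divergence(model: torch.nn.Module, low: float = -10.0, high: float = 10.0):
    """Run every iteration: the weight quartiles should stay within a sane band."""
    for name, p in model.named_parameters():
        flat = p.detach().float().flatten()
        lower_quartile = torch.quantile(flat, 0.25).item()
        upper_quartile = torch.quantile(flat, 0.75).item()
        if lower_quartile < low or upper_quartile > high:
            raise RuntimeError(
                f"{name} may be diverging: quartiles "
                f"({lower_quartile:.3g}, {upper_quartile:.3g}) outside [{low}, {high}]"
            )
```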

Problem 4: Unstable parameters

In deep networks, hidden layers' outputs can vary quite rapidly, causing learning difficulties for subsequent layers. Conversely, the opposite can also happen, where the layer parameters change too slowly. Both issues can be identified by comparing the magnitude of the parameter gradients against the parameters themselves:

$$ -4 < \log_{10} \left( \frac{\mathrm{mean}(|\nabla\theta|)}{\mathrm{mean}(|\theta|)} \right) < -1 $$

According to Leon Bottou, a good rule of thumb is to keep the ratio of parameter updates to parameter values at around 0.01, or -2 in the $\log_{10}$ expression above.
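
A sketch of this ratio check in PyTorch, run after `loss.backward()`. Note it uses the raw gradients as in the formula above, not the actual parameter updates, and the (-4, -1) band comes straight from that formula:

```python
import math
import torch


def check_gradient_to_param_ratio(model: torch.nn.Module, low: float = -4.0, high: float = -1.0):
    """Call after loss.backward(): compare mean |grad| to mean |param| per parameter tensor."""
    for name, p in model.named_parameters():
        if p.grad is None:
            continue
        grad_mag = p.grad.detach().abs().mean().item()
        param_mag = p.detach().abs().mean().item()
        if grad_mag == 0 or param_mag == 0:
            continue  # covered by the untrained-parameter / dead-unit checks
        log_ratio = math.log10(grad_mag / param_mag)
        if not (low < log_ratio < high):
            print(f"warning: {name} log10(|grad|/|param|) = {log_ratio:.2f}, outside ({low}, {high})")
```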

Problem 5: Activation saturation

Bounded activation functions like the sigmoid have a region in which their gradients are informative; outside this region the gradients are very small. It's best to keep the inputs to these activation functions within their active range. To check whether a neuron is suffering from saturation, the $\rho_B$ measure can be used. This measure first breaks the output range of the activation function into $B$ bins; the historic activations of a given neuron are then binned and the saturation measure is computed as follows:

$$\rho_B = \frac{\sum^B_{b=1} |\bar{g}'_b| N_b}{\sum^B_{b=1} N_b}$$

  • $B$ - the total number of bins
  • $\bar{g}'_b$ - the average of the output values in bin $b$, scaled to the range $[-1, 1]$
  • $N_b$ - the number of outputs in bin $b$

$\rho_B$ tends towards 1 as saturation increases and towards 0 as the activations stay within the active range.

Implementation-wise, we can store the last $N$ activations and compute $\rho_B$ over them; if the value approaches 1 then the neuron has become saturated.
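
A sketch of the $\rho_B$ computation for a single neuron, assuming a sigmoid whose outputs lie in [0, 1] (the bin count and output range defaults are assumed choices). A buffer of outputs piled up near 0 and 1 gives a value close to 1, flagging the neuron as saturated:

```python
import torch


def rho_B(activations: torch.Tensor, out_min: float = 0.0, out_max: float = 1.0, B: int = 50) -> float:
    """Binned saturation measure over a 1-D buffer of a neuron's recent outputs."""
    # Scale the outputs to [-1, 1] so values near either bound carry weight ~1.
    scaled = 2 * (activations - out_min) / (out_max - out_min) - 1
    # Assign each scaled output to one of B equal-width bins over [-1, 1].
    bin_idx = ((scaled + 1) / 2 * B).long().clamp_(0, B - 1)
    numerator, denominator = 0.0, 0.0
    for b in range(B):
        in_bin = scaled[bin_idx == b]
        n_b = in_bin.numel()
        if n_b == 0:
            continue
        g_bar = in_bin.mean().abs().item()  # |scaled average of the outputs in bin b|
        numerator += g_bar * n_b
        denominator += n_b
    return numerator / denominator
```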

Problem 6: Dead ReLUs

ReLUs only have a non-zero gradient for positive inputs. It is possible for a ReLU to always output 0 (i.e. its input is always negative), which renders the neuron useless. This can be detected by keeping the last $N$ outputs of a ReLU in a buffer and checking whether they are all equal to 0; if so, the neuron has probably died. One can keep track of the ratio of dead ReLUs per layer and raise an error if this exceeds a threshold.
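
A sketch of the dead-ReLU check, assuming you keep an (N, num_units) buffer of a layer's post-ReLU outputs (the 20% threshold is an assumption):

```python
import torch


def dead_relu_ratio(activation_buffer: torch.Tensor) -> float:
    """activation_buffer: (N, num_units) tensor of a layer's last N post-ReLU outputs.

    A unit that produced exactly 0 for all N recent inputs is counted as dead.
    """
    dead = (activation_buffer == 0).all(dim=0)
    return dead.float().mean().item()


def check_dead_relus(activation_buffer: torch.Tensor, max_dead_ratio: float = 0.2):
    ratio = dead_relu_ratio(activation_buffer)
    if ratio > max_dead_ratio:
        raise RuntimeError(f"{ratio:.0%} of this layer's ReLUs appear to be dead")
```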

Problem 7: Unable to fit a small sample

Any deep NN should be able to fit a small dataset (tens of examples) with ease, just by memorising the data. If the model is incapable of doing this, it signals an issue with the training setup.
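
A sketch of an overfit-a-tiny-batch test in PyTorch (the optimiser, step count, learning rate and tolerance are all assumed values, not prescribed by the paper):

```python
import torch


def check_can_memorise(model, loss_fn, inputs, targets,
                       steps: int = 500, lr: float = 1e-3, tol: float = 1e-2):
    """Train on a handful of examples; a healthy setup should drive the loss near zero."""
    optimiser = torch.optim.Adam(model.parameters(), lr=lr)
    for _ in range(steps):
        optimiser.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimiser.step()
    if loss.item() > tol:
        raise RuntimeError(f"Could not fit {len(inputs)} examples; final loss was {loss.item():.4f}")
```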

Problem 8: Zero loss

If the loss reaches 0 it is highly likely the network has overfit the training data. Regularisation should be added to avoid this. Check that your network does not reach 0 loss when fitting a small dataset; this verifies that the regularisation is effective.
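
A minimal zero-loss guard (the epsilon is an arbitrary choice):

```python
def check_loss_not_zero(loss_value: float, eps: float = 1e-8):
    """A training loss at (or essentially at) 0 usually means the data has been memorised."""
    if loss_value < eps:
        print("warning: loss has reached ~0; the network has likely overfit, check regularisation")
```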

Problem 9: Slowly decreasing loss

Suboptimal learning rates or a bad model will cause training to proceed very slowly. Computing the ratio of the current loss against the previous iteration's loss tells you how quickly training is proceeding. You probably also want a decaying average of this ratio, as the loss value can be quite noisy across batches.
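
One way to implement the smoothed loss-ratio check in plain Python (the decay factor and stall threshold are assumed values):

```python
class SlowLossCheck:
    """Tracks a decaying average of the ratio between consecutive losses."""

    def __init__(self, decay: float = 0.9, stall_threshold: float = 0.999):
        self.decay = decay
        self.stall_threshold = stall_threshold
        self.smoothed_ratio = None
        self.prev_loss = None

    def update(self, loss_value: float) -> None:
        if self.prev_loss is not None and self.prev_loss > 0:
            ratio = loss_value / self.prev_loss
            if self.smoothed_ratio is None:
                self.smoothed_ratio = ratio
            else:
                self.smoothed_ratio = self.decay * self.smoothed_ratio + (1 - self.decay) * ratio
            # A smoothed ratio stuck near (or above) 1 means the loss is barely moving.
            if self.smoothed_ratio > self.stall_threshold:
                print(f"warning: loss barely decreasing (smoothed ratio {self.smoothed_ratio:.4f})")
        self.prev_loss = loss_value
```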

Problem 10: Diverging loss

Too high a learning rate or an incorrectly implemented loss function can cause the loss to increase rather than decrease. One technique for detecting this is to record the lowest loss value seen so far and compare it to the current loss; if the ratio of current to lowest loss exceeds a threshold, then throw an error.
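
A sketch of the divergence check (the maximum allowed ratio of 3x is an assumption):

```python
class LossDivergenceCheck:
    """Compares each new loss against the lowest loss seen so far."""

    def __init__(self, max_ratio: float = 3.0):
        self.max_ratio = max_ratio
        self.best_loss = float("inf")

    def update(self, loss_value: float) -> None:
        self.best_loss = min(self.best_loss, loss_value)
        if self.best_loss > 0 and loss_value / self.best_loss > self.max_ratio:
            raise RuntimeError(
                f"Loss appears to be diverging: current {loss_value:.4f} vs best {self.best_loss:.4f}"
            )
```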

Problem 11: Vanishing gradients

Some activation functions can have quite small gradients, which, when backpropagated, compound and cause earlier layers to receive very small gradients. This effect causes the early layers to learn very slowly. One can compare the average magnitude of the gradients in the last layer to that in the first layer; if the log of this ratio is outside an acceptable range, throw an error:

$$\mathrm{min} < \log_{10} \left( \frac{\mathrm{mean}(|\nabla\theta_L|)}{\mathrm{mean}(|\nabla\theta_0|)} \right) < \mathrm{max}$$
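
A sketch of the first-vs-last layer gradient comparison in PyTorch. It assumes `model.parameters()` yields parameters roughly in layer order (true for typical sequentially-defined models), and the min/max bounds are placeholders:

```python
import math
import torch


def check_vanishing_gradients(model: torch.nn.Module,
                              min_log_ratio: float = -6.0, max_log_ratio: float = 6.0):
    """Call after loss.backward(): compare gradient magnitudes in the last vs the first layer."""
    grads = [p.grad for p in model.parameters() if p.grad is not None]
    if len(grads) < 2:
        return
    first_layer_mag = grads[0].abs().mean().item()
    last_layer_mag = grads[-1].abs().mean().item()
    if first_layer_mag == 0 or last_layer_mag == 0:
        raise RuntimeError("Zero gradients in the first or last layer")
    log_ratio = math.log10(last_layer_mag / first_layer_mag)
    if not (min_log_ratio < log_ratio < max_log_ratio):
        raise RuntimeError(f"log10(last/first layer gradient magnitude) = {log_ratio:.2f} is out of range")
```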
