Disambiguating Classification Losses

Neural networks that perform classification (predict the class of an input) produce a vector of C raw numbers for every input, where C is the total number of classes (for example, 10 for MNIST). This vector is called the logits; each of its elements is the logit for one class.

If we feed an MNIST image of the digit 3 to a classification model, it may produce a logit vector like this (top row added to show the class IDs):

| 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| -121.2 | -212.3 | 81.1 | 171.1 | -55.0 | 132.5 | -13.2 | 63.5 | 99.2 | -10.9 |

We can apply softmax() over the logit vector to get a probability for each class ID:

| 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 1.13e-127 | 3.10e-167 | 8.19e-40 | 1.0 | 6.39e-99 | 1.7e-17 | 9.11e-81 | 1.86e-47 | 5.94e-32 | 9.08e-80 |

The actual class ID, or target, 3 can be encoded as a one-hot vector:

| 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
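
To make the numbers above concrete, here is a minimal NumPy sketch (my own illustration, not part of the original gist) that computes the softmax probabilities and builds the one-hot target; softmax() and target_onehot are hypothetical names reused in the later sketches.

```python
import numpy as np

# Raw logits produced for the digit-3 image, indexed by class ID 0..9.
logits = np.array([-121.2, -212.3, 81.1, 171.1, -55.0,
                   132.5, -13.2, 63.5, 99.2, -10.9])

def softmax(x):
    # Subtracting the max keeps exp() from overflowing; the result is unchanged.
    e = np.exp(x - x.max())
    return e / e.sum()

probs = softmax(logits)   # close to 1.0 at index 3, vanishingly small elsewhere

target_id = 3             # the true class of the input image
target_onehot = np.zeros(10)
target_onehot[target_id] = 1.0   # [0 0 0 1 0 0 0 0 0 0]
```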
  • For computing the loss from logits and targets, the suitable loss function is plain cross entropy loss, which computes np.dot(target, -np.log(softmax(logits))).

  • For computing the loss from probabilities and targets, the probabilities first need to be converted to log-probabilities and then passed to the negative log likelihood function, which simply computes np.dot(target, -log_probabilities). The two routes are compared in the sketch below.
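
A short self-contained sketch of the two routes giving the same number (my own example, using the hypothetical names from the earlier snippet):

```python
import numpy as np

logits = np.array([-121.2, -212.3, 81.1, 171.1, -55.0,
                   132.5, -13.2, 63.5, 99.2, -10.9])
target_onehot = np.eye(10)[3]      # one-hot vector for class ID 3

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

# Route 1: cross entropy computed straight from the raw logits.
ce_from_logits = np.dot(target_onehot, -np.log(softmax(logits)))

# Route 2: convert probabilities to log-probabilities first, then take the
# negative log likelihood of the target class.
log_probs = np.log(softmax(logits))
nll_from_log_probs = np.dot(target_onehot, -log_probs)

assert np.isclose(ce_from_logits, nll_from_log_probs)
```

If the gist is referring to PyTorch (an assumption on my part), route 1 corresponds to nn.CrossEntropyLoss, which expects raw logits, and route 2 to nn.LogSoftmax followed by nn.NLLLoss.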

When performing binary classification, there are two mutually exclusive classes (C=2). The network is designed to output a single logit, and the target class ID is either 0 or 1. The softmax() operation then reduces to a single sigmoid() evaluation of that logit.

  • For computing the loss from the logit and target, the suitable loss function is binary cross entropy with logits, which computes -(target * log(sigmoid(logit)) + (1 - target) * log(1 - sigmoid(logit))).

  • For computing the loss from a probability, i.e. when the network has a sigmoid layer built into it, the loss function is binary cross entropy loss, which evaluates the same expression with sigmoid(logit) replaced by the network's probability output. Both variants are sketched below.
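
A corresponding sketch for the binary case, assuming a single raw logit and a 0/1 target (my own example values, not from the gist):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

logit = 2.3      # the single raw output of the network
target = 1.0     # true class ID, either 0 or 1

# Binary cross entropy computed directly from the logit.
bce_with_logits = -(target * np.log(sigmoid(logit))
                    + (1 - target) * np.log(1 - sigmoid(logit)))

# Binary cross entropy computed from a probability, i.e. when the
# network already ends in a sigmoid layer.
prob = sigmoid(logit)
bce_from_prob = -(target * np.log(prob) + (1 - target) * np.log(1 - prob))

assert np.isclose(bce_with_logits, bce_from_prob)
```

In PyTorch terms (again an assumption), these correspond to nn.BCEWithLogitsLoss and nn.BCELoss respectively; the logits variant is generally preferred because it can be evaluated in a numerically stable way.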
