EMD Loss with Confidence and Weights
Gist: deshwalmahesh/80614105a01554bcc6065c9e30291e8e (last active March 8, 2024 10:07)
import torch
import torch.nn as nn
class WeightedEMDLossWithConfidencePenalty(nn.Module):
    """
    Original idea behind this loss:
    1. If the actual class was "0", predicting class "4" should be penalised more than predicting class "2". This is addressed by the vanilla EMD loss itself, because it compares cumulative distributions over ordered classes.
    2. Class-based penalisation. For example, with weights [2, 2, 1, 1, 1], a sample whose true class is "0" or "1" contributes twice as much loss as a sample from any of the remaining classes. Addressed by `class_weights`.
    3. If the model predicts the correct class but with low confidence, penalise it. For example, if the true class is "0" and the model predicts "0", a probability of 0.6 should be penalised more than a probability of 0.9. Addressed by the `correct_pred_low_prob_penalty` logic.
    4. If the model predicts the wrong class with high confidence, penalise it even further. For example, if the true class is "0" and the model predicts class "3", a probability of 0.9 should be penalised more than a probability of 0.25. Addressed by `incorrect_pred_high_prob_penalty`.
    Intended loss ordering, from lowest to highest:
    1. Correct prediction with high confidence: lowest loss.
    2. Correct prediction with low confidence: higher than the above.
    3. Incorrect prediction with low confidence: higher than the above two.
    4. Incorrect prediction with high confidence: highest loss.
    The order of cases 2 and 3 is not guaranteed; it depends on the penalty values chosen.
    """
    def __init__(self, num_classes: int, r=2, class_weights=None, correct_pred_low_prob_penalty=None,
                 incorrect_pred_high_prob_penalty=None, squared=True, summed=False, sanity_check=True):
        """
        Args:
            num_classes: Number of labels / classes in the classification task (>= 2). The classes are assumed to be ordered.
            r: Order of the norm (2 means L2) used by the non-squared EMD. Ignored when `squared=True`.
            class_weights: None or a 1-D tensor of length `num_classes`. Scales each sample's loss by the weight of its true class. Usually helps with class imbalance and with recall on specific classes.
            correct_pred_low_prob_penalty: None or a float (>= 1.0 is a good choice). Adds `penalty * (1 - p)` to correct predictions, where `p` is the predicted probability, so under-confident correct predictions cost more. Helps precision by pushing up the probability of the correct classes.
            incorrect_pred_high_prob_penalty: None or a float (>= 1.0 is a good choice). Adds `penalty * p` to wrong predictions, where `p` is the probability of the wrongly predicted class. Helps recall by suppressing the probability of wrong predictions.
            squared: If True, compute the squared EMD (averaged over classes) as in https://arxiv.org/pdf/1611.05916.pdf; otherwise compute the normalised L-r EMD.
            summed: If True, sum the per-sample losses over the batch; otherwise take the batch mean.
            sanity_check: If True, check whether `logits` are raw logits or already softmaxed (softmaxing them if needed) and whether `labels` are one-hot encoded (encoding them if needed). This is slightly slower. If False, `logits` must already be probabilities and `labels` must already be one-hot.
        """
        super(WeightedEMDLossWithConfidencePenalty, self).__init__()
        if class_weights is not None:
            class_weights = torch.as_tensor(class_weights, dtype=torch.float)
            if class_weights.shape != (num_classes,):
                raise ValueError(f"`class_weights` must be a 1-D tensor of length {num_classes}, got shape {tuple(class_weights.shape)}")
        self.num_classes = num_classes
        self.r = r
        self.register_buffer("class_weights", class_weights)  # Buffer, so it follows the module on `.to(device)`
        self.CPLPP = correct_pred_low_prob_penalty
        self.IPHPP = incorrect_pred_high_prob_penalty
        self.squared = squared
        self.summed = summed
        self.sanity_check = sanity_check
    def is_softmaxed(self, tensor):
        """
        Check whether `tensor` already holds per-row probabilities (softmax output) rather than raw logits.
        """
        row_sums = tensor.sum(dim=1)  # Every row, not the whole batch, must sum to 1
        return bool((tensor >= 0).all() and (tensor <= 1).all() and torch.allclose(row_sums, torch.ones_like(row_sums)))
    def is_one_hot(self, tensor):
        """
        Sanity check for a 2-D one-hot encoding.
        """
        if tensor.dim() != 2: return False  # Must be 2-D
        row_sums = tensor.sum(dim=1)
        if not torch.all(row_sums == 1): return False  # Every row sums to 1
        unique_values = tensor.unique()
        if not torch.all((unique_values == 0) | (unique_values == 1)): return False  # Only 0s and 1s
        return True
    def forward(self, logits, labels):
        """
        Args:
            logits: Raw model outputs (logits) of shape [BATCH × num_classes]. With `sanity_check=True`, already-softmaxed probabilities are detected and used as-is.
            labels: True labels, either one-hot of shape [BATCH × num_classes] or class indices of shape [BATCH]. With `sanity_check=True`, class indices are one-hot encoded to [BATCH × num_classes].
        Returns:
            Scalar loss: the batch mean, or the batch sum if `summed=True`.
        """
        if self.sanity_check:
            if labels.dim() == 2 and labels.shape[1] == 1: labels = labels.squeeze(1)  # [BATCH, 1] -> [BATCH]; also safe for a batch of size 1
            if not self.is_one_hot(labels): labels = torch.nn.functional.one_hot(labels.long(), num_classes=self.num_classes)
            if not self.is_softmaxed(logits): logits = torch.nn.functional.softmax(logits, dim=1)
        probs = logits  # From here on these are per-row probabilities
        assert probs.shape == labels.shape, f"Shape of the two distribution batches must be the same. Got {probs.shape} and {labels.shape}"
        labels = labels.to(probs.dtype)
        cdf_diff = torch.cumsum(labels, dim=1) - torch.cumsum(probs, dim=1)  # Difference between the two CDFs
        if self.squared:
            emd = torch.square(cdf_diff).mean(dim=1)  # Squared EMD per sample, shape [BATCH]
        else:
            emd = cdf_diff.abs().pow(self.r).mean(dim=1).pow(1 / self.r)  # Normalised L-r EMD per sample (single r-th root)
        if self.class_weights is not None:
            emd = emd * self.class_weights[labels.argmax(dim=1)]  # Weight each sample by its true class
        if (self.CPLPP is not None) or (self.IPHPP is not None):  # If at least one penalty is enabled
            # An alternative is a weighted CrossEntropyLoss blended as alpha * EMD + (1 - alpha) * CrossEntropy.
            # Which one performs better would need experimentation on the task at hand.
            pred_conf, pred_labels = probs.max(dim=1)  # Probability and index of the predicted class
            correct_preds = pred_labels == labels.argmax(dim=1)
            penalty = torch.zeros_like(emd)
            if self.CPLPP is not None:  # For CORRECT predictions, penalise low confidence
                # Added, not multiplied: (1 - p) < 1, so multiplying would shrink the loss, the opposite of the desired effect
                penalty = torch.where(correct_preds, self.CPLPP * (1 - pred_conf), penalty)
            if self.IPHPP is not None:  # For INCORRECT predictions, penalise high confidence
                penalty = torch.where(~correct_preds, self.IPHPP * pred_conf, penalty)
            emd = emd + penalty
        return emd.sum() if self.summed else emd.mean()
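Point 1 of the docstring (ordinal penalisation) is what separates EMD from plain cross-entropy. The sketch below illustrates it in dependency-free Python rather than PyTorch. `cdf`, `normalized_emd` and `cross_entropy` are hypothetical helpers written only for this example, and `normalized_emd` assumes the normalised L-r form: the mean over classes of |CDF difference|^r, followed by the r-th root.

```python
import math
from itertools import accumulate


def cdf(probs):
    """Cumulative distribution over ordered classes (hypothetical helper)."""
    return list(accumulate(probs))


def normalized_emd(probs, target, r=2):
    """(mean_k |CDF_target(k) - CDF_probs(k)|^r)^(1/r) for a single sample."""
    diffs = [abs(t - p) for t, p in zip(cdf(target), cdf(probs))]
    return (sum(d ** r for d in diffs) / len(diffs)) ** (1 / r)


def cross_entropy(probs, true_class):
    """Negative log-probability of the true class."""
    return -math.log(probs[true_class])


target = [1.0, 0.0, 0.0, 0.0, 0.0]     # true class "0" (one-hot)
near_miss = [0.1, 0.0, 0.9, 0.0, 0.0]  # most mass on class "2"
far_miss = [0.1, 0.0, 0.0, 0.0, 0.9]   # most mass on class "4"

emd_near = normalized_emd(near_miss, target)
emd_far = normalized_emd(far_miss, target)
print(f"EMD near={emd_near:.3f} far={emd_far:.3f}")
print(f"CE  near={cross_entropy(near_miss, 0):.3f} far={cross_entropy(far_miss, 0):.3f}")
```

Here the far miss has the larger EMD (about 0.805 vs 0.569), while cross-entropy is identical for both (about 2.303), because it only looks at the probability of the true class.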
TESTS:
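A quick check of the four-case ordering from the class docstring. This is not the author's test code. It is a minimal, dependency-free sketch that mirrors the per-sample logic of the loss in plain Python, assuming the squared-EMD branch averaged over classes, with `correct_pred_low_prob_penalty=1.0` and `incorrect_pred_high_prob_penalty=2.0`. `squared_emd` and `penalized_loss` are hypothetical helpers.

```python
from itertools import accumulate


def squared_emd(probs, true_class):
    """Mean squared difference between the predicted CDF and the one-hot target CDF."""
    cdf_pred = list(accumulate(probs))
    cdf_true = [1.0 if k >= true_class else 0.0 for k in range(len(probs))]
    return sum((t - p) ** 2 for t, p in zip(cdf_true, cdf_pred)) / len(probs)


def penalized_loss(probs, true_class, cplpp=1.0, iphpp=2.0):
    """Squared EMD plus the confidence penalty for one sample (hypothetical helper)."""
    pred = max(range(len(probs)), key=lambda k: probs[k])  # argmax over classes
    confidence = probs[pred]
    base = squared_emd(probs, true_class)
    if pred == true_class:
        return base + cplpp * (1.0 - confidence)  # correct but under-confident
    return base + iphpp * confidence  # wrong, scaled by confidence in the wrong class


cases = {
    "correct, high confidence": [0.9, 0.05, 0.05],
    "correct, low confidence": [0.5, 0.3, 0.2],
    "incorrect, low confidence": [0.3, 0.4, 0.3],
    "incorrect, high confidence": [0.05, 0.9, 0.05],
}
losses = {name: penalized_loss(p, true_class=0) for name, p in cases.items()}
for name, value in losses.items():
    print(f"{name:28s} {value:.4f}")
```

With these inputs the losses are roughly 0.104, 0.597, 0.993 and 2.102, which matches the intended ordering. With `iphpp=1.0`, the third case drops to about 0.593, just below the second. That is why the ordering of cases 2 and 3 depends on the penalty values.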