EMD Loss with Confidence and Weights
import torch
import torch.nn as nn


class WeightedEMDLossWithConfidencePenalty(nn.Module):
    """
    Original idea behind this loss:
    1. If the actual class was "0" and the model predicted class "4", penalise it more than if it had predicted class "2": addressed by the vanilla EMD loss itself.
    2. Penalisation based on classes. For example, if the weights are [2,2,1,1,1], a wrong prediction on class "0" or "1" is penalised twice as much as a wrong prediction on the remaining classes: addressed by `class_weights`.
    3. If the model predicts the correct class but with low confidence, penalise it. For example, if the true class was "0" and the model predicted class "0", penalise it more when the probability was 0.6 than when it was 0.9: addressed by the `correct_pred_low_prob_penalty` logic.
    4. If the model predicts the wrong class, and does so with high confidence, penalise it even further. For example, if the true class was "0" and the model predicted class "3", penalise it more when the wrong class got probability 0.9 than when it got 0.25: addressed by `incorrect_pred_high_prob_penalty`.

    Resulting ordering of the loss:
    1. Correct prediction with high confidence: least loss.
    2. Correct prediction with low confidence: higher loss than the above.
    3. Incorrect prediction with low confidence: higher loss than the above two.
    4. Incorrect prediction with high confidence: maximum loss.
    """

    def __init__(self, num_classes: int, r=2, class_weights=None, implicit_weighing=False,
                 correct_pred_low_prob_penalty=None, incorrect_pred_high_prob_penalty=None,
                 squared=True, summed=False, sanity_check=True):
        """
        Args:
            num_classes: >= 2. Number of labels / classes in your classification task.
            r: The norm to use (2 means L2).
            class_weights: None or a 1-D torch tensor of length num_classes. Increases or decreases the loss for specific classes. Usually helps with imbalance and with improving recall on important classes.
            implicit_weighing: Can only be used when `class_weights` is set.
            correct_pred_low_prob_penalty: [None or a float >= 1.0 is good] Whether to penalise correct predictions that are under-confident. Helps precision by pushing up the probability of the correct class.
            incorrect_pred_high_prob_penalty: [None or a float >= 1.0 is good] Whether to penalise wrong predictions further based on their confidence. Helps recall by suppressing the probability of wrong predictions.
            squared: Whether to compute the normalised vanilla (L1, L2, ...) EMD loss or the squared EMD loss given in the paper: https://arxiv.org/pdf/1611.05916.pdf
            summed: Only used with `squared`. Whether to sum the squared CDF differences over the classes of each sample before the final batch mean.
            sanity_check: If enabled, checks whether logits or softmaxed probabilities were passed and whether the labels are one-hot encoded, and converts both internally. Note: it'll be a bit slower.
        """
        super(WeightedEMDLossWithConfidencePenalty, self).__init__()
        if implicit_weighing and (class_weights is None):
            raise ValueError("`implicit_weighing` can only be used when `class_weights` are set")
        self.num_classes = num_classes
        self.r = r
        self.class_weights = class_weights
        self.implicit_weighing = implicit_weighing
        self.CPLPP = correct_pred_low_prob_penalty
        self.IPHPP = incorrect_pred_high_prob_penalty
        self.squared = squared
        self.summed = summed
        self.sanity_check = sanity_check

    def is_softmaxed(self, tensor):
        """
        Small helper to check whether we have logits or softmaxed probabilities.
        """
        return (tensor >= 0).all() and (tensor <= 1).all() and torch.isclose(tensor.sum(), torch.tensor(1.0))

    def is_one_hot(self, tensor):
        """
        Small helper to sanity-check the one-hot encodings.
        """
        if tensor.dim() != 2: return False  # Check if the tensor is 2D
        row_sums = tensor.sum(dim=1)
        if not torch.all(row_sums == 1): return False  # Check if all rows sum to 1
        unique_values = tensor.unique()
        if not torch.all((unique_values == 0) | (unique_values == 1)): return False  # Check if all values are 0 or 1
        return True

    def forward(self, logits, labels):
        """
        Args:
            logits: raw output (logits) for predicted labels of shape [BATCH x num_classes]
            labels: actual labels of shape [BATCH x num_classes] or [BATCH]. If 1-D, they are one-hot encoded to [BATCH x num_classes]
        """
        if self.sanity_check:
            if len(logits) > 1: labels = labels.squeeze()  # Don't squeeze away the batch dim when batch size is 1
            if not self.is_one_hot(labels): labels = torch.nn.functional.one_hot(labels.long(), num_classes=self.num_classes)
            if not self.is_softmaxed(logits): logits = torch.nn.functional.softmax(logits, dim=1)

        assert logits.shape == labels.shape, f"Shape of the two distribution batches must be the same. Got {logits.shape} and {labels.shape}"

        cum_logits = torch.cumsum(logits, dim=1)
        cum_labels = torch.cumsum(labels, dim=1)

        if self.squared:
            emd = torch.square(cum_labels - cum_logits)
            if self.summed: emd = emd.sum(axis=1)
        else:
            emd = torch.linalg.norm(cum_labels - cum_logits, ord=self.r, dim=1) ** (1 / self.r)

        if self.class_weights is not None: emd *= self.class_weights[labels.argmax(dim=1)]

        if (self.CPLPP is not None) or (self.IPHPP is not None):  # If at least one of the two is activated
            # An alternative is to blend a weighted CrossEntropyLoss with the original loss as
            # alpha * original loss + (1 - alpha) * CrossEntropy loss (see the sketch after the class definition).
            # Would require experimentation on how that performs against the code below, depending on the task.
            pred_labels = logits.argmax(dim=1)
            gt_labels = labels.argmax(dim=1)
            correct_preds = (pred_labels == gt_labels)

            if self.CPLPP is not None:  # For CORRECT predictions, penalise more for low confidence
                multiplier_correct = 1 - logits[correct_preds, pred_labels[correct_preds]]
                emd[correct_preds] += (multiplier_correct * self.CPLPP)  # We add rather than multiply: the values are < 1, and multiplying by < 1 again would have the opposite of the desired effect

            if self.IPHPP is not None:  # For INCORRECT predictions, penalise more for high confidence
                multiplier_incorrect = logits[~correct_preds, pred_labels[~correct_preds]]
                emd[~correct_preds] += (multiplier_incorrect * self.IPHPP)

        return emd.mean()
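
A minimal sketch of plugging the loss into a training step (the model, optimizer, shapes and hyper-parameters below are placeholders, not part of the gist). `summed=True` reduces the squared CDF differences to one value per sample, which keeps the class-weight and confidence-penalty terms shape-compatible for batches larger than one:

import torch
import torch.nn as nn

model = nn.Linear(128, 5)  # hypothetical 5-class classifier head
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = WeightedEMDLossWithConfidencePenalty(
    num_classes=5,
    class_weights=torch.tensor([2.0, 2.0, 1.0, 1.0, 1.0]),
    correct_pred_low_prob_penalty=1.0,
    incorrect_pred_high_prob_penalty=1.0,
    summed=True,
)

features = torch.randn(8, 128)       # dummy batch of 8 samples
targets = torch.randint(0, 5, (8,))  # integer labels; the loss one-hot encodes them internally

optimizer.zero_grad()
loss = criterion(model(features), targets)  # raw logits are fine: sanity_check softmaxes them
loss.backward()
optimizer.step()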
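
And a minimal sketch of the blending alternative mentioned in the comment inside `forward`: mix the plain EMD loss with a weighted `CrossEntropyLoss` instead of using the explicit confidence penalties. The `alpha` value and the weights here are illustrative choices, not values from the gist:

import torch
import torch.nn as nn

num_classes = 5
alpha = 0.7  # illustrative mixing coefficient
class_weights = torch.tensor([2.0, 2.0, 1.0, 1.0, 1.0])

emd = WeightedEMDLossWithConfidencePenalty(num_classes=num_classes, class_weights=class_weights, summed=True)
ce = nn.CrossEntropyLoss(weight=class_weights)  # expects raw logits and integer targets

def blended_loss(logits, targets):
    # EMD keeps the ordinal structure of the classes; CrossEntropy sharpens confidence in the true class
    return alpha * emd(logits, targets) + (1 - alpha) * ce(logits, targets)

loss = blended_loss(torch.randn(8, num_classes), torch.randint(0, num_classes, (8,)))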
@deshwalmahesh (Author) commented:
TESTS:

emd_loss =  WeightedEMDLossWithConfidencePenalty(num_classes = 3, r = 2, class_weights = None, implicit_weighing = False, correct_pred_low_prob_penalty = None,
                                                 incorrect_pred_high_prob_penalty = None, sanity_check = True)

labels = torch.Tensor([[1, 0, 0]])
logits = torch.Tensor([[0.1,0.2,0.7]])
print(emd_loss(logits, labels))

labels = torch.Tensor([[1,0,0]]) 
logits = torch.Tensor([[0.1,0.7,0.2]])  
print(emd_loss(logits, labels))

labels = torch.Tensor([[1,0,0]])
logits = torch.Tensor([[0.2,0.1,0.7]]) 
print(emd_loss(logits, labels))

labels = torch.Tensor([[1,0,0]])
logits = torch.Tensor([[0.2,0.7,0.1]]) 
print(emd_loss(logits, labels))

labels = torch.Tensor([[1,0,0]])
logits = torch.Tensor([[0.7,0.1,0.2]]) 
print(emd_loss(logits, labels))

labels = torch.Tensor([[1,0,0]])
logits = torch.Tensor([[0.7,0.2,0.1]]) 
print(emd_loss(logits, labels))


# DEBUG for batches
print("-"*30, "DEBUG", "-"*30)
labels = torch.Tensor([2, 0])  
logits = torch.Tensor([[0.1,0.2,0.7], [0.2, 0.3, 0.5]])
print(emd_loss(logits, labels))

labels = torch.Tensor([[0, 0, 1], [1, 0, 0]])  
logits = torch.Tensor([[0.1,0.2,0.7], [0.2, 0.3, 0.5]])
print(emd_loss(logits, labels))


labels = torch.Tensor([0])  
logits = torch.Tensor([[0.1,0.2,0.7]])
print(emd_loss(logits, labels))

---------

emd_loss =  WeightedEMDLossWithConfidencePenalty(num_classes = 3, r = 2, class_weights = None, implicit_weighing = False, correct_pred_low_prob_penalty = 1,
                                                 incorrect_pred_high_prob_penalty = None, sanity_check = True)

print("-"*30, "Correct Pred Prob", "-"*30)

labels = torch.Tensor([[1, 0, 0]])
logits = torch.Tensor([[0.1,0.2,0.7]])
print("BASE - Incorrect - No Effect: ",emd_loss(logits, labels))


labels = torch.Tensor([[1, 0, 0]])
logits = torch.Tensor([[0.5,0.2,0.3]])
print(emd_loss(logits, labels))

labels = torch.Tensor([[1, 0, 0]])
logits = torch.Tensor([[0.5,0.3,0.2]])
print(emd_loss(logits, labels))

labels = torch.Tensor([[1, 0, 0]])
logits = torch.Tensor([[0.51,0.29,0.20]])
print(emd_loss(logits, labels))

labels = torch.Tensor([[1, 0, 0]])
logits = torch.Tensor([[0.51,0.3,0.19]])
print(emd_loss(logits, labels))

labels = torch.Tensor([[1, 0, 0]])
logits = torch.Tensor([[0.9,0.09,0.01]])
print(emd_loss(logits, labels))

labels = torch.Tensor([[1, 0, 0]])
logits = torch.Tensor([[0.9,0.1,0.00]])
print("BEST: ", emd_loss(logits, labels))


# DEBUG for batches
print("-"*30, "Batch Size DEBUG", "-"*30)
labels = torch.Tensor([2, 0])  
logits = torch.Tensor([[0.1,0.2,0.7], [0.2, 0.3, 0.5]])
print(emd_loss(logits, labels))

labels = torch.Tensor([[0, 0, 1], [1, 0, 0]])  
logits = torch.Tensor([[0.1,0.2,0.7], [0.2, 0.3, 0.5]])
print(emd_loss(logits, labels))

labels = torch.Tensor([0])  
logits = torch.Tensor([[0.1,0.2,0.7]])
print(emd_loss(logits, labels))

----------

emd_loss =  WeightedEMDLossWithConfidencePenalty(num_classes = 3, r = 2, class_weights = None, implicit_weighing = False, correct_pred_low_prob_penalty = None,
                                                 incorrect_pred_high_prob_penalty = 1, sanity_check = True)

print("-"*30, "Incorrect Pred Prob", "-"*30)

labels = torch.Tensor([[1, 0, 0]])
logits = torch.Tensor([[0.7,0.2,0.1]])
print("BASE - Correct - No Effect: ",emd_loss(logits, labels))


labels = torch.Tensor([[1, 0, 0]])
logits = torch.Tensor([[0.3, 0.5, 0.2]])
print(emd_loss(logits, labels))

labels = torch.Tensor([[1, 0, 0]])
logits = torch.Tensor([[0.2, 0.5, 0.3]])
print(emd_loss(logits, labels))

labels = torch.Tensor([[1, 0, 0]])
logits = torch.Tensor([[0.3, 0.2, 0.5]])
print(emd_loss(logits, labels))


labels = torch.Tensor([[1, 0, 0]])
logits = torch.Tensor([[0.2, 0.3, 0.5]])
print(emd_loss(logits, labels))

labels = torch.Tensor([[1, 0, 0]])
logits = torch.Tensor([[0.49,0.26,0.25]])
print(emd_loss(logits, labels))

labels = torch.Tensor([[1, 0, 0]])
logits = torch.Tensor([[0.49,0.49,0.02]])
print("BEST: ", emd_loss(logits, labels))


# DEBUG for batches
print("-"*30, "Batch Size DEBUG", "-"*30)
labels = torch.Tensor([2, 0])  
logits = torch.Tensor([[0.1,0.2,0.7], [0.2, 0.3, 0.5]])
print(emd_loss(logits, labels))

labels = torch.Tensor([[0, 0, 1], [1, 0, 0]])  
logits = torch.Tensor([[0.1,0.2,0.7], [0.2, 0.3, 0.5]])
print(emd_loss(logits, labels))

labels = torch.Tensor([0])  
logits = torch.Tensor([[0.1,0.2,0.7]])
print(emd_loss(logits, labels))

---------

emd_loss =  WeightedEMDLossWithConfidencePenalty(num_classes = 3, r = 2, class_weights = None, implicit_weighing = False, correct_pred_low_prob_penalty = 0.5,
                                                 incorrect_pred_high_prob_penalty = 0.5, sanity_check = True)


labels = torch.Tensor([[1, 0, 0]])
logits = torch.Tensor([[0.7,0.2,0.1]])
print("Correct Pred: ",emd_loss(logits, labels))

labels = torch.Tensor([[1, 0, 0]])
logits = torch.Tensor([[0.9,0.1,0.00]])
print(emd_loss(logits, labels))


labels = torch.Tensor([[1, 0, 0]])
logits = torch.Tensor([[0.1,0.2,0.7]])
print("Incorrect Pred ",emd_loss(logits, labels))


labels = torch.Tensor([[1, 0, 0]])
logits = torch.Tensor([[0.49,0.49,0.02]])
print(emd_loss(logits, labels))

-------

emd_loss =  WeightedEMDLossWithConfidencePenalty(num_classes = 3, r = 2, class_weights = torch.Tensor([2,1,2]), 
                                                 implicit_weighing = False, correct_pred_low_prob_penalty = None,
                                                 incorrect_pred_high_prob_penalty = None, sanity_check = True)


labels = torch.Tensor([[1, 0, 0]])
logits = torch.Tensor([[0.34,0.33,0.33]])
print(emd_loss(logits, labels))

labels = torch.Tensor([[0, 1, 0]])
logits = torch.Tensor([[0.33,0.33,0.34]])
print(emd_loss(logits, labels))

labels = torch.Tensor([[0, 1, 0]])
logits = torch.Tensor([[0.33,0.34,0.33]])
print(emd_loss(logits, labels))

------

emd_loss =  WeightedEMDLossWithConfidencePenalty(num_classes = 3, r = 2, class_weights = torch.Tensor([2,1,1]), 
                                                 implicit_weighing = True, correct_pred_low_prob_penalty = None,
                                                 incorrect_pred_high_prob_penalty = None, sanity_check = True)


labels = torch.Tensor([[1, 0, 0]])
logits = torch.Tensor([[0.34,0.33,0.33]])
print(emd_loss(logits, labels))

labels = torch.Tensor([[0, 1, 0]])
logits = torch.Tensor([[0.33,0.33,0.34]])
print(emd_loss(logits, labels))

labels = torch.Tensor([[0, 1, 0]])
logits = torch.Tensor([[0.33,0.34,0.33]])
print(emd_loss(logits, labels))
