Created
February 9, 2023 09:33
-
-
Save rohitdavas/5edd8692279fe85eb6a1a7e51235469a to your computer and use it in GitHub Desktop.
Balanced Binary cross entropy loss function
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
This is a simple implementation of a class-balanced binary cross-entropy loss.
Test cases : done. | |
""" | |
import torch | |
def balanced_BCE_loss(predictions: torch.Tensor, ground_truth: torch.Tensor, with_logits=False) -> torch.Tensor:
    """Return the class-balanced binary cross-entropy of a batch.

    Every element is weighted by ``total / (2 * class_count)``, where
    ``class_count`` is the number of elements in the batch that share its
    label, so that positives and negatives contribute equally to the loss
    regardless of how imbalanced the batch is.

    Args:
        predictions: Probabilities in [0, 1], or raw logits when
            ``with_logits`` is true. Must have the same shape as
            ``ground_truth`` (typically ``(Batch_size, 1)``).
        ground_truth: Binary targets (0 or 1) as a floating-point tensor.
            Elements that are neither 0 nor 1 receive a weight of 0.
        with_logits: When true, ``predictions`` are treated as logits and
            ``binary_cross_entropy_with_logits`` is used.

    Returns:
        A scalar tensor: the mean of the weighted per-element losses.

    Raises:
        AssertionError: If the two tensors differ in shape.
    """
    assert ground_truth.shape == predictions.shape, f"shape mismatch : ground_truth.shape = {ground_truth.shape}, predictions.shape = {predictions.shape}"

    epsilon = 1e-6  # keeps the divisions finite when a class is absent

    # Count over every element (not just the first axis) so the total agrees
    # with the positive count, which is summed over the whole tensor.
    total_samples = ground_truth.numel()
    positive_count = torch.sum(ground_truth).float()

    # Pad each count independently; deriving the negative count from the
    # already-padded positive count would cancel the epsilon.
    n_positives = positive_count + epsilon
    n_negatives = total_samples - positive_count + epsilon

    negative_weight = (1 / n_negatives) * (total_samples / 2.0)
    positive_weight = (1 / n_positives) * (total_samples / 2.0)

    # Per-element weights; anything that is not exactly 0 or 1 stays at 0.
    weights = torch.zeros_like(ground_truth)
    weights[ground_truth == 0] = negative_weight
    weights[ground_truth == 1] = positive_weight

    if with_logits:
        return torch.nn.functional.binary_cross_entropy_with_logits(input=predictions, target=ground_truth, weight=weights, reduction="mean")
    return torch.nn.functional.binary_cross_entropy(input=predictions, target=ground_truth, weight=weights, reduction="mean")
# tests : | |
def call(y_pred, y_true):
    """Print both tensors (values and shapes) followed by their balanced BCE loss."""
    report = f"""
y_true = {y_true.tolist()}, shape = {y_true.shape}
y_pred = {y_pred.tolist()}, shape = {y_pred.shape}
loss = {balanced_BCE_loss(y_pred, y_true)}
"""
    print(report)
# Demo 1: a flat two-element batch holding one sample of each class.
y_true = torch.tensor([0, 1], dtype=torch.float32)
y_pred = torch.tensor([0.0, 0.0], dtype=torch.float32)
call(y_pred, y_true)

# Remaining demos: batches of four, reshaped into (4, 1) columns.
# Each row is (labels, predictions).
_COLUMN_DEMOS = (
    # predictions fixed at 0 while the number of positive labels grows 0 -> 4
    ([0, 0, 0, 0], [0, 0, 0, 0]),
    ([1, 0, 0, 0], [0, 0, 0, 0]),
    ([1, 1, 0, 0], [0, 0, 0, 0]),
    ([1, 1, 1, 0], [0, 0, 0, 0]),
    ([1, 1, 1, 1], [0, 0, 0, 0]),
    # predictions fixed at 1 while the number of positive labels grows 0 -> 3
    ([0, 0, 0, 0], [1, 1, 1, 1]),
    ([1, 0, 0, 0], [1, 1, 1, 1]),
    ([1, 1, 0, 0], [1, 1, 1, 1]),
    ([1, 1, 1, 0], [1, 1, 1, 1]),
    # all-positive labels against a soft prediction
    ([1, 1, 1, 1], [0.3, 0.3, 0.3, 0.3]),
)
for _labels, _scores in _COLUMN_DEMOS:
    y_true = torch.tensor([_labels], dtype=torch.float32).reshape(-1, 1)
    y_pred = torch.tensor([_scores], dtype=torch.float32).reshape(-1, 1)
    call(y_pred, y_true)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
The formula for the class weights was taken from here: https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/structured_data/imbalanced_data.ipynb#scrollTo=cveQoiMyGQCo