@rohitdavas
Created February 9, 2023 09:33
Balanced Binary cross entropy loss function
"""
This is simple implementation of a binary cross entropy.
Test cases : done.
"""
import torch
def balanced_BCE_loss(predictions: torch.Tensor, ground_truth: torch.Tensor, with_logits=False) -> torch.Tensor:
    """
    ground_truth : (Batch_size, 1)
    predictions  : (Batch_size, 1)
    """
    assert ground_truth.shape == predictions.shape, f"shape mismatch : ground_truth.shape = {ground_truth.shape}, predictions.shape = {predictions.shape}"

    # total number of samples in the batch (n_positives + n_negatives)
    total_samples = ground_truth.shape[0]

    # count positive and negative samples; the small epsilon avoids division by zero
    n_positives = torch.sum(ground_truth).float() + 1e-6
    n_negatives = total_samples - n_positives + 1e-6

    # per-class weight: each class contributes half of the total weight
    values = {}
    values[0] = (1 / n_negatives) * (total_samples / 2.0)
    values[1] = (1 / n_positives) * (total_samples / 2.0)
    print(f"values : {values}")

    # expand the per-class weights into a per-sample weight tensor
    weights = torch.zeros_like(ground_truth)
    weights[ground_truth == 0] = values[0]
    weights[ground_truth == 1] = values[1]
    print(f"weights : {weights}")

    if with_logits:
        return torch.nn.functional.binary_cross_entropy_with_logits(input=predictions, target=ground_truth, weight=weights, reduction="mean")
    else:
        return torch.nn.functional.binary_cross_entropy(input=predictions, target=ground_truth, weight=weights, reduction="mean")
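# Worked example of the weighting (illustrative numbers, mirroring the formula above):
# for a batch of 4 with 1 positive and 3 negatives,
#   values[1] ~= (1 / 1) * (4 / 2) = 2.0    -> each positive sample counts double
#   values[0] ~= (1 / 3) * (4 / 2) ~= 0.67  -> each negative counts for two thirds
# so both classes contribute roughly equally to the mean loss.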
# tests :
def call(y_pred, y_true):
    print(f"""
    y_true = {y_true.tolist()}, shape = {y_true.shape}
    y_pred = {y_pred.tolist()}, shape = {y_pred.shape}
    loss = {balanced_BCE_loss(y_pred, y_true)}
    """)
# create a tensor of 0s and 1s
y_true = torch.tensor([0, 1], dtype=torch.float32)
y_pred = torch.tensor([0.0, 0.0], dtype=torch.float32)
call(y_pred, y_true)
# create a tensor of 0s and 1s for a batch size of 4
y_true = torch.tensor([[0, 0, 0, 0]], dtype=torch.float32).reshape(-1, 1)
y_pred = torch.tensor([[0, 0, 0, 0]], dtype=torch.float32).reshape(-1, 1)
call(y_pred, y_true)
# when one is 1 and rest are 0
y_true = torch.tensor([[1, 0, 0, 0]], dtype=torch.float32).reshape(-1, 1)
y_pred = torch.tensor([[0, 0, 0, 0]], dtype=torch.float32).reshape(-1, 1)
call(y_pred, y_true)
# when two are 1 and rest are 0
y_true = torch.tensor([[1, 1, 0, 0]], dtype=torch.float32).reshape(-1, 1)
y_pred = torch.tensor([[0, 0, 0, 0]], dtype=torch.float32).reshape(-1, 1)
call(y_pred, y_true)
# when three are 1 and rest are 0
y_true = torch.tensor([[1, 1, 1, 0]], dtype=torch.float32).reshape(-1, 1)
y_pred = torch.tensor([[0, 0, 0, 0]], dtype=torch.float32).reshape(-1, 1)
call(y_pred, y_true)
# when all are 1
y_true = torch.tensor([[1, 1, 1, 1]], dtype=torch.float32).reshape(-1, 1)
y_pred = torch.tensor([[0, 0, 0, 0]], dtype=torch.float32).reshape(-1, 1)
call(y_pred, y_true)
# when all are 0's and y_pred is 1
y_true = torch.tensor([[0, 0, 0, 0]], dtype=torch.float32).reshape(-1, 1)
y_pred = torch.tensor([[1, 1, 1, 1]], dtype=torch.float32).reshape(-1, 1)
call(y_pred, y_true)
y_true = torch.tensor([[1, 0, 0, 0]], dtype=torch.float32).reshape(-1, 1)
y_pred = torch.tensor([[1, 1, 1, 1]], dtype=torch.float32).reshape(-1, 1)
call(y_pred, y_true)
y_true = torch.tensor([[1, 1, 0, 0]], dtype=torch.float32).reshape(-1, 1)
y_pred = torch.tensor([[1, 1, 1, 1]], dtype=torch.float32).reshape(-1, 1)
call(y_pred, y_true)
y_true = torch.tensor([[1, 1, 1, 0]], dtype=torch.float32).reshape(-1, 1)
y_pred = torch.tensor([[1, 1, 1, 1]], dtype=torch.float32).reshape(-1, 1)
call(y_pred, y_true)
y_true = torch.tensor([[1, 1, 1, 1]], dtype=torch.float32).reshape(-1, 1)
y_pred = torch.tensor([[0.3, 0.3, 0.3, 0.3]], dtype=torch.float32).reshape(-1, 1)
call(y_pred, y_true)
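# Extra check for the with_logits=True path (illustrative logit values): predictions
# are raw scores here and binary_cross_entropy_with_logits applies the sigmoid internally.
y_true = torch.tensor([[1, 0, 0, 0]], dtype=torch.float32).reshape(-1, 1)
logits = torch.tensor([[2.0, -1.0, -3.0, 0.5]], dtype=torch.float32).reshape(-1, 1)
print(f"with_logits=True loss = {balanced_BCE_loss(logits, y_true, with_logits=True)}")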