@dsantiago
Created April 5, 2021 01:35
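import numpy as np
import torch

# Built-in binary cross-entropy; the default reduction='mean' averages over all elements
crit = torch.nn.BCELoss()
preds = torch.tensor([[1, 0.2, 0.3, 1, 1], [1, 0.1, 0.7, 1, 1]]).float()  # probabilities in [0, 1], one column per class
targets = torch.tensor([[1, 1, 1, 0, 0], [1, 0, 1, 1, 1]]).float()  # binary labels (0 or 1), one column per class
print(preds, targets)
print(crit(preds, targets))
#---
# Manual reproduction: EPS keeps log() finite when a probability is exactly 0
EPS = 1e-12
res = targets * torch.log(preds + EPS) + (1 - targets) * torch.log(1 - preds + EPS)
res *= -1
# Only a log(0) term reaches -log(EPS) (about 27.63). BCELoss clamps its log
# outputs at -100 instead, so such a term costs 100; overwrite it to match
res[res >= -np.log(EPS)] = 100
# Average over every batch x class element, like reduction='mean'
print(res.mean())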
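For reference, both halves of the script compute the mean binary cross-entropy over all N = 2 × 5 = 10 elements. Per the PyTorch BCELoss documentation, each log term is clamped to be at least -100, which is why a confident wrong prediction (p = 1, y = 0) costs exactly 100. The second line plugs in the example tensors (natural logs, rounded): the two (p = 1, y = 0) entries give 100 each, the (p = 1, y = 1) entries give 0, and the remaining four give -ln 0.2, -ln 0.3, -ln 0.9 and -ln 0.7.

```latex
\ell = -\frac{1}{N}\sum_{i=1}^{N}\Big[\, y_i \max(\log p_i,\,-100) + (1-y_i)\max\big(\log(1-p_i),\,-100\big) \Big]

\ell \approx \frac{100 + 100 + 1.609 + 1.204 + 0.105 + 0.357}{10} \approx 20.3275
```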
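A more direct way to match BCELoss, sketched here under the assumption that the documented clamp at -100 is the only special handling involved, is to clamp each log term rather than adding EPS and patching the result afterwards. `manual_bce` is a hypothetical helper name, not a PyTorch API. The check uses `torch.nn.functional.binary_cross_entropy` with `reduction="none"` so every element is compared, not only the mean.

```python
import torch
import torch.nn.functional as F

# Same example tensors as above: probabilities and 0/1 labels, one column per class
preds = torch.tensor([[1, 0.2, 0.3, 1, 1], [1, 0.1, 0.7, 1, 1]], dtype=torch.float32)
targets = torch.tensor([[1, 1, 1, 0, 0], [1, 0, 1, 1, 1]], dtype=torch.float32)


def manual_bce(p: torch.Tensor, y: torch.Tensor, reduction: str = "mean") -> torch.Tensor:
    """Hypothetical helper: binary cross-entropy with each log term clamped at -100."""
    # Clamp *before* multiplying by the label: log(0) is -inf, and 0 * -inf would be nan
    log_p = torch.clamp(torch.log(p), min=-100)
    log_1mp = torch.clamp(torch.log(1 - p), min=-100)
    loss = -(y * log_p + (1 - y) * log_1mp)
    if reduction == "none":
        return loss
    if reduction == "sum":
        return loss.sum()
    return loss.mean()  # default: average over every batch x class element


builtin_mean = torch.nn.BCELoss()(preds, targets)
builtin_none = F.binary_cross_entropy(preds, targets, reduction="none")
mine_mean = manual_bce(preds, targets)
mine_none = manual_bce(preds, targets, reduction="none")

print(builtin_mean, mine_mean)
print(torch.allclose(builtin_none, mine_none))
```

Clamping before the multiplication is the key design choice: it sidesteps the same log(0) problem the EPS trick works around, without needing a threshold to detect which terms to overwrite.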