F1 score in PyTorch
import torch


def f1_loss(y_true: torch.Tensor, y_pred: torch.Tensor, is_training=False) -> torch.Tensor:
    '''Calculate F1 score. Can work with GPU tensors.

    The original implementation was written by Michal Haltuf on Kaggle.

    Returns
    -------
    torch.Tensor
        A scalar (`ndim` == 0) with 0 <= val <= 1.

    Reference
    ---------
    - https://www.kaggle.com/rejpalcz/best-loss-function-for-f1-score-metric
    - https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html#sklearn.metrics.f1_score
    - https://discuss.pytorch.org/t/calculating-precision-recall-and-f1-score-in-case-of-multi-label-classification/28265/6
    '''
    assert y_true.ndim == 1
    assert y_pred.ndim == 1 or y_pred.ndim == 2

    if y_pred.ndim == 2:
        y_pred = y_pred.argmax(dim=1)

    # Confusion-matrix counts for the binary {0, 1} case.
    tp = (y_true * y_pred).sum().to(torch.float32)
    tn = ((1 - y_true) * (1 - y_pred)).sum().to(torch.float32)  # unused, kept for completeness
    fp = ((1 - y_true) * y_pred).sum().to(torch.float32)
    fn = (y_true * (1 - y_pred)).sum().to(torch.float32)

    epsilon = 1e-7
    precision = tp / (tp + fp + epsilon)
    recall = tp / (tp + fn + epsilon)

    f1 = 2 * (precision * recall) / (precision + recall + epsilon)
    # Note: argmax above is not differentiable, so this works as an eval
    # metric rather than a trainable loss.
    f1.requires_grad = is_training
    return f1
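For example, called on hard 0/1 label tensors it agrees with scikit-learn up to epsilon (a minimal sketch with made-up data, not from the gist):

import torch
from sklearn.metrics import f1_score

y_true = torch.tensor([1, 0, 1, 1, 0, 1])
y_pred = torch.tensor([1, 0, 0, 1, 0, 1])

print(f1_loss(y_true, y_pred))                   # tensor(0.8571)
print(f1_score(y_true.numpy(), y_pred.numpy()))  # 0.8571..., agreeing up to epsilon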
@SuperShinyEyes commented Oct 15, 2019

Tested with PyTorch v1.1 on GPU.

@SuperShinyEyes commented Oct 15, 2019

import torch
import torch.nn as nn
import torch.nn.functional as F


class F1_Loss(nn.Module):
    '''Calculate F1 score. Can work with GPU tensors.

    The original implementation was written by Michal Haltuf on Kaggle.

    Returns
    -------
    torch.Tensor
        A scalar loss (`ndim` == 0) with epsilon <= val <= 1 - epsilon.

    Reference
    ---------
    - https://www.kaggle.com/rejpalcz/best-loss-function-for-f1-score-metric
    - https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html#sklearn.metrics.f1_score
    - https://discuss.pytorch.org/t/calculating-precision-recall-and-f1-score-in-case-of-multi-label-classification/28265/6
    - http://www.ryanzhang.info/python/writing-your-own-loss-function-module-for-pytorch/
    '''
    def __init__(self, epsilon=1e-7):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, y_pred, y_true):
        assert y_pred.ndim == 2  # (batch_size, num_classes) logits
        assert y_true.ndim == 1  # (batch_size,) integer class indices

        # Note: the number of classes (2) is hard-coded here.
        y_true = F.one_hot(y_true, 2).to(torch.float32)
        # Soft probabilities instead of argmax keep the computation differentiable.
        y_pred = F.softmax(y_pred, dim=1)

        # Per-class soft confusion-matrix counts, summed over the batch.
        tp = (y_true * y_pred).sum(dim=0).to(torch.float32)
        tn = ((1 - y_true) * (1 - y_pred)).sum(dim=0).to(torch.float32)  # unused, kept for completeness
        fp = ((1 - y_true) * y_pred).sum(dim=0).to(torch.float32)
        fn = (y_true * (1 - y_pred)).sum(dim=0).to(torch.float32)

        precision = tp / (tp + fp + self.epsilon)
        recall = tp / (tp + fn + self.epsilon)

        f1 = 2 * (precision * recall) / (precision + recall + self.epsilon)
        f1 = f1.clamp(min=self.epsilon, max=1 - self.epsilon)
        # Average over classes and invert so that minimizing the loss maximizes F1.
        return 1 - f1.mean()


f1_loss = F1_Loss().cuda()
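A minimal training-step sketch (the model, data, and shapes below are assumed placeholders, not part of the gist; drop .cuda() when no GPU is available):

import torch
import torch.nn as nn

model = nn.Linear(10, 2).cuda()        # toy classifier: 10 features -> 2 logits
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = F1_Loss().cuda()

x = torch.randn(32, 10).cuda()         # batch of 32 examples
y = torch.randint(0, 2, (32,)).cuda()  # integer class labels in {0, 1}

optimizer.zero_grad()
loss = criterion(model(x), y)          # scalar, differentiable
loss.backward()
optimizer.step()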
@pingaowang commented Apr 11, 2020

Thank you for sharing!

@frannfuri commented Apr 18, 2020

nice!

@Chiang97912 commented Sep 2, 2020

It works!

@sudonto commented Nov 4, 2020

@SuperShinyEyes, in your code, you wrote assert y_true.ndim == 1, so this code doesn't accept the batch size axis?

@deltonmyalil commented Mar 15, 2021

Thank you.

@fmellomascarenhas commented Apr 1, 2021

> @SuperShinyEyes, in your code, you wrote assert y_true.ndim == 1, so this code doesn't accept the batch size axis?

I believe it is because the code expects each example in the batch to be labeled by a class index rather than a one-hot vector. This explains the line: y_true = F.one_hot(y_true, 2).to(torch.float32)
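A quick shape check illustrates this (a sketch with an assumed toy batch of 4):

import torch
import torch.nn.functional as F

y_true = torch.tensor([0, 1, 1, 0])  # ndim == 1: one class index per example
print(F.one_hot(y_true, 2).shape)    # torch.Size([4, 2]) -- the batch axis is preserved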

@vinitrinh commented Aug 2, 2021

Can this F1 "Loss" be backpropagated, or is it just an eval metric?
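A quick empirical check, as a sketch with assumed toy inputs (it relies on the F1_Loss class defined above): because forward uses F.softmax rather than argmax, gradients do flow through it, whereas the standalone f1_loss above hard-labels with argmax and is effectively an eval metric.

import torch

logits = torch.randn(8, 2, requires_grad=True)  # fake logits for a batch of 8
labels = torch.randint(0, 2, (8,))

loss = F1_Loss()(logits, labels)
loss.backward()
print(logits.grad is not None)  # True: the softmax path is differentiable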
