Created
September 3, 2021 05:26
-
-
Save e96031413/b3f1755d1e3dc841189a795ba3b469ba to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Copyright (c) Facebook, Inc. and its affiliates. | |
All rights reserved. | |
This source code is licensed under the license found in the | |
LICENSE file in the root directory of this source tree. | |
Portions of the source code are from the OLTR project which | |
notice below and in LICENSE in the root directory of | |
this source tree. | |
Copyright (c) 2019, Zhongqi Miao | |
All rights reserved. | |
""" | |
import torch | |
from torch.nn.modules.loss import _Loss | |
import torch.nn.functional as F | |
import json | |
class BalancedSoftmax(_Loss):
    """Balanced Softmax Loss.

    Reads per-class sample counts from a JSON file once at construction
    time and passes them to `balanced_softmax_loss` on every forward call.
    """

    def __init__(self, freq_path):
        """Load the per-class sample counts.

        Args:
          freq_path: Path to a JSON file whose content `torch.tensor` can
            convert. Presumably a flat list holding one count per class,
            in class-index order -- confirm against the file in use.
        """
        super().__init__()
        # JSON is UTF-8 by specification; do not rely on the platform's
        # default text encoding when reading it.
        with open(freq_path, 'r', encoding='utf-8') as fd:
            freq = json.load(fd)
        # A non-persistent buffer follows the module through .to()/.cuda(),
        # so the counts are not re-copied to the logits' device on every
        # forward call, while state_dict() stays exactly as it was when this
        # was a plain attribute.
        self.register_buffer('sample_per_class', torch.tensor(freq),
                             persistent=False)

    def forward(self, input, label, reduction='mean'):
        """Return the Balanced Softmax Loss of `input` against `label`.

        Args:
          input: Logits tensor, forwarded as `logits`.
          label: Ground-truth class indices, forwarded as `labels`.
          reduction: string. One of "none", "mean", "sum".
        Returns:
          Whatever `balanced_softmax_loss` returns for these arguments.
        """
        # Note the argument order: the helper takes labels first.
        return balanced_softmax_loss(label, input, self.sample_per_class, reduction)
def balanced_softmax_loss(labels, logits, sample_per_class, reduction='mean'):
    """Compute the Balanced Softmax Loss between `logits` and the ground truth `labels`.

    The natural log of each class's sample count is added to that class's
    logit, and the shifted logits are passed to `F.cross_entropy`.

    Args:
      labels: A int tensor of size [batch] (or [batch, d1, ...] when
        `logits` has extra trailing dimensions).
      logits: A float tensor of size [batch, no_of_classes]. Inputs of size
        [batch, no_of_classes, d1, ...] and unbatched [no_of_classes] are
        accepted as well, as far as `F.cross_entropy` supports them.
      sample_per_class: A int tensor of size [no of classes].
      reduction: string. One of "none", "mean", "sum". Defaults to "mean".
    Returns:
      loss: A float tensor. Balanced Softmax Loss.
    """
    # Cast the counts to the logits' type so the addition below is valid.
    log_prior = sample_per_class.type_as(logits).log()
    if logits.dim() > 1:
        # F.cross_entropy keeps classes on dim 1; give every other dim size 1
        # so broadcasting lines the prior up with the class axis.
        log_prior = log_prior.reshape([1, -1] + [1] * (logits.dim() - 2))
    adjusted = logits + log_prior
    return F.cross_entropy(input=adjusted, target=labels, reduction=reduction)
def create_loss(freq_path):
    """Announce on stdout and build a `BalancedSoftmax` from `freq_path`.

    Args:
      freq_path: Passed unchanged to the `BalancedSoftmax` constructor.
    Returns:
      The newly constructed `BalancedSoftmax` instance.
    """
    print('Loading Balanced Softmax Loss.')
    criterion = BalancedSoftmax(freq_path)
    return criterion
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment