@e96031413
Created September 3, 2021 05:26
"""Copyright (c) Facebook, Inc. and its affiliates.
All rights reserved.
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
Portions of the source code are from the OLTR project which
notice below and in LICENSE in the root directory of
this source tree.
Copyright (c) 2019, Zhongqi Miao
All rights reserved.
"""
import json

import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss
class BalancedSoftmax(_Loss):
    """
    Balanced Softmax Loss.

    Args:
        freq_path: path to a JSON file holding a list with the number of
            training samples for each class.
    """

    def __init__(self, freq_path):
        super(BalancedSoftmax, self).__init__()
        with open(freq_path, 'r') as fd:
            freq = json.load(fd)
        freq = torch.tensor(freq)
        self.sample_per_class = freq

    def forward(self, input, label, reduction='mean'):
        return balanced_softmax_loss(label, input, self.sample_per_class, reduction)
def balanced_softmax_loss(labels, logits, sample_per_class, reduction):
    """Compute the Balanced Softmax Loss between `logits` and the ground truth `labels`.

    Args:
        labels: An int tensor of size [batch].
        logits: A float tensor of size [batch, no_of_classes].
        sample_per_class: An int tensor of size [no_of_classes].
        reduction: string. One of "none", "mean", "sum".
    Returns:
        loss: A float tensor. Balanced Softmax Loss.
    """
    spc = sample_per_class.type_as(logits)
    spc = spc.unsqueeze(0).expand(logits.shape[0], -1)
    # Shifting each logit by log(n_c) makes softmax(logits)_c proportional to
    # n_c * exp(z_c), the class-prior-adjusted posterior.
    logits = logits + spc.log()
    loss = F.cross_entropy(input=logits, target=labels, reduction=reduction)
    return loss
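
# The adjustment above implements the Balanced Softmax from "Balanced
# Meta-Softmax for Long-Tailed Visual Recognition" (Ren et al., NeurIPS 2020),
# which replaces the usual softmax with
#     phi_c = n_c * exp(z_c) / sum_j n_j * exp(z_j),
# where n_c is the training-sample count of class c. Folding log(n_c) into the
# logits and reusing cross_entropy computes the same quantity in a numerically
# stable way.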
def create_loss(freq_path):
    print('Loading Balanced Softmax Loss.')
    return BalancedSoftmax(freq_path)
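
# --- Usage sketch ------------------------------------------------------------
# A minimal, hedged example of driving the code above; it is not part of the
# original gist. The class counts, the temporary JSON file, and the tensor
# shapes are illustrative assumptions: the gist itself only requires that
# `freq_path` point to a JSON list holding one training-sample count per class.
if __name__ == "__main__":
    import tempfile

    # Hypothetical long-tailed counts: class 0 is 100x more frequent than class 2.
    counts = [1000, 100, 10]
    with tempfile.NamedTemporaryFile('w', suffix='.json', delete=False) as f:
        json.dump(counts, f)
        freq_path = f.name

    criterion = create_loss(freq_path)
    logits = torch.randn(4, 3)           # [batch, no_of_classes]
    labels = torch.tensor([0, 1, 2, 2])  # [batch]
    loss = criterion(logits, labels)     # scalar, since reduction defaults to 'mean'
    print(loss.item())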