Skip to content

Instantly share code, notes, and snippets.

@nasimrahaman
Last active November 16, 2023 04:54
Show Gist options
  • Save nasimrahaman/a5fb23f096d7b0c3880e1622938d0901 to your computer and use it in GitHub Desktop.
Save nasimrahaman/a5fb23f096d7b0c3880e1622938d0901 to your computer and use it in GitHub Desktop.
Pytorch instance-wise weighted cross-entropy loss
import torch
import torch.nn as nn
def log_sum_exp(x):
# See implementation detail in
# http://timvieira.github.io/blog/post/2014/02/11/exp-normalize-trick/
# b is a shift factor. see link.
# x.size() = [N, C]:
b, _ = torch.max(x, 1)
y = b + torch.log(torch.exp(x - b.expand_as(x)).sum(1))
# y.size() = [N, 1]. Squeeze to [N] and return
return y.squeeze(1)
def class_select(logits, target):
# in numpy, this would be logits[:, target].
batch_size, num_classes = logits.size()
if target.is_cuda:
device = target.data.get_device()
one_hot_mask = torch.autograd.Variable(torch.arange(0, num_classes)
.long()
.repeat(batch_size, 1)
.cuda(device)
.eq(target.data.repeat(num_classes, 1).t()))
else:
one_hot_mask = torch.autograd.Variable(torch.arange(0, num_classes)
.long()
.repeat(batch_size, 1)
.eq(target.data.repeat(num_classes, 1).t()))
return logits.masked_select(one_hot_mask)
def cross_entropy_with_weights(logits, target, weights=None):
assert logits.dim() == 2
assert not target.requires_grad
target = target.squeeze(1) if target.dim() == 2 else target
assert target.dim() == 1
loss = log_sum_exp(logits) - class_select(logits, target)
if weights is not None:
# loss.size() = [N]. Assert weights has the same shape
assert list(loss.size()) == list(weights.size())
# Weight the loss
loss = loss * weights
return loss
class CrossEntropyLoss(nn.Module):
"""
Cross entropy with instance-wise weights. Leave `aggregate` to None to obtain a loss
vector of shape (batch_size,).
"""
def __init__(self, aggregate='mean'):
super(CrossEntropyLoss, self).__init__()
assert aggregate in ['sum', 'mean', None]
self.aggregate = aggregate
def forward(self, input, target, weights=None):
if self.aggregate == 'sum':
return cross_entropy_with_weights(input, target, weights).sum()
elif self.aggregate == 'mean':
return cross_entropy_with_weights(input, target, weights).mean()
elif self.aggregate is None:
return cross_entropy_with_weights(input, target, weights)
@FoobarProtocol
Copy link

This is a great implementation. Beautiful job and bravo. Going to re-publish as a gist for transparency purposes for anybody that's wondering. Will give credit to the OG author for this. Keep up the good work. This has helped to serve me well in my journey currently.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment