Skip to content

Instantly share code, notes, and snippets.

@benwu232
Created August 25, 2017 06:50
Show Gist options
  • Save benwu232/1fbf1cd6b637810f5d57902fa6d4ef1b to your computer and use it in GitHub Desktop.
Save benwu232/1fbf1cd6b637810f5d57902fa6d4ef1b to your computer and use it in GitHub Desktop.
weight matrix loss for pytorch
def one_hot(size, index):
    """Create a matrix of one-hot row vectors.

    Args:
        size: Shape ``(num_rows, num_classes)`` of the result (tuple, list or
            ``torch.Size``).
        index: Integer (long) tensor with one class index per row, of shape
            ``(num_rows,)`` or ``(num_rows, 1)``.

    Returns:
        A ``torch.long`` tensor of shape ``size``, on the same device as
        ``index``, holding 1 at ``[i, index[i]]`` and 0 everywhere else.

    Example::

        >>> index = torch.LongTensor([2, 0, 1]).view(-1, 1)
        >>> one_hot((3, 3), index).tolist()
        [[0, 0, 1], [1, 0, 0], [0, 1, 0]]
    """
    # Allocate on index's device so scatter_ never mixes CPU and CUDA tensors
    # (the old Variable/volatile wrapping was removed in PyTorch 0.4).
    y_onehot = torch.zeros(tuple(size), dtype=torch.long, device=index.device)
    # Write the scalar 1 into column index[i] of every row i.
    return y_onehot.scatter_(1, index.view(-1, 1), 1)
class WeightMatrixLoss(torch.nn.Module):
    """Negative log-likelihood style loss reweighted by a class-pair matrix.

    For sample ``i`` with true class ``t_i`` and predicted class
    ``y_i = argmax(p_onehot[i])`` the loss term is
    ``-p_onehot[i, t_i] * weight_matrix[t_i][y_i]``; ``forward`` returns the
    sum of these terms divided by the batch size.

    Args:
        weight_matrix: An ``N x N`` matrix (nested list, NumPy array or
            tensor) describing the weights between classes: entry ``[t][y]``
            scales the loss of a sample whose true class is ``t`` and whose
            predicted class is ``y``. ``None`` means every weight is 1, i.e.
            the plain mean of ``-p_onehot[i, t_i]``.
    """

    def __init__(self, weight_matrix=None):
        super().__init__()
        self.weight_matrix = weight_matrix

    def forward(self, p_onehot, target):
        """Compute the weighted loss for one batch.

        Args:
            p_onehot: ``(batch, N)`` floating tensor of per-class scores.
                Presumably log-probabilities (e.g. ``log_softmax`` output);
                the term used is ``-p_onehot[i, target[i]]``, which is an NLL
                only under that assumption — confirm against the caller.
            target: ``(batch,)`` or ``(batch, 1)`` long tensor of true class
                indices, on any device.

        Returns:
            A tensor of shape ``(1,)`` on ``p_onehot``'s device. Gradients flow
            through ``p_onehot`` only; the looked-up weights are constants.
        """
        # Follow the prediction's device instead of hard-coding .cuda().
        target = target.view(-1).to(p_onehot.device)
        batch_size = target.numel()
        # -p[i, target[i]]: same value as the one-hot dot product, without
        # materialising the one-hot matrix or casting p's dtype.
        nll = -p_onehot.gather(1, target.unsqueeze(1)).squeeze(1)
        # The prediction only selects a weight; it must not carry gradients.
        predict = p_onehot.detach().argmax(dim=1)
        if self.weight_matrix is None:
            weights = torch.ones_like(nll)
        else:
            matrix = torch.as_tensor(
                self.weight_matrix, dtype=p_onehot.dtype, device=p_onehot.device
            )
            # Vectorised replacement for the per-sample weight_matrix[t][y] loop.
            weights = matrix.detach()[target, predict]
        return ((nll * weights).sum() / batch_size).view(-1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment