Simple multi-label classification example with PyTorch and MultiLabelSoftMarginLoss (https://en.wikipedia.org/wiki/Multi-label_classification)
import torch
import torch.nn as nn
import numpy as np
import torch.optim as optim
from torch.autograd import Variable  # legacy wrapper; a no-op since PyTorch 0.4

# (1, 0) => target labels 0+2
# (0, 1) => target labels 1
# (1, 1) => target labels 2
train = []
labels = []
for i in range(10000):
    category = (np.random.choice([0, 1]), np.random.choice([0, 1]))
    if category == (1, 0):
        train.append([np.random.uniform(0.1, 1), 0])
        labels.append([1, 0, 1])
    if category == (0, 1):
        train.append([0, np.random.uniform(0.1, 1)])
        labels.append([0, 1, 0])
    if category == (1, 1):
        train.append([np.random.uniform(0.1, 1), np.random.uniform(0.1, 1)])
        labels.append([0, 0, 1])


class _classifier(nn.Module):
    def __init__(self, nlabel):
        super(_classifier, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(2, 64),
            nn.ReLU(),
            nn.Linear(64, nlabel),  # raw logits; the loss applies the sigmoid
        )

    def forward(self, input):
        return self.main(input)


nlabel = len(labels[0])  # => 3
classifier = _classifier(nlabel)

optimizer = optim.Adam(classifier.parameters())
criterion = nn.MultiLabelSoftMarginLoss()  # expects multi-hot float targets

epochs = 5
for epoch in range(epochs):
    losses = []
    for i, sample in enumerate(train):
        inputv = Variable(torch.FloatTensor(sample)).view(1, -1)
        labelsv = Variable(torch.FloatTensor(labels[i])).view(1, -1)

        output = classifier(inputv)
        loss = criterion(output, labelsv)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())  # store a Python float, not a tensor
    print('[%d/%d] Loss: %.3f' % (epoch+1, epochs, np.mean(losses)))
$ python multilabel.py
[1/5] Loss: 0.092
[2/5] Loss: 0.005
[3/5] Loss: 0.001
[4/5] Loss: 0.000
[5/5] Loss: 0.000

raghavgoyal14 commented Nov 14, 2017

Shouldn't the line 9 comment be # (1, 1) => target labels 2?

bartolsthoorn (author) commented Jan 10, 2018

@raghavgoyal14 yes, you're right. I hope it's clear from the labels.append([0, 0, 1]) :)


simonhessner commented Sep 18, 2018

How exactly would you evaluate your model in the end? The output of the network is a float value between 0 and 1, but you want 1 (true) or 0 (false) as prediction in the end. So you have to find a threshold for each label. How is this done?
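A minimal sketch of one common answer, assuming the gist's network. Strictly, the last layer is a plain nn.Linear, so the model returns unbounded logits rather than values in [0, 1]; MultiLabelSoftMarginLoss applies the sigmoid internally. Applying torch.sigmoid gives per-label probabilities, and each label is then predicted independently by comparing its probability with a threshold. The 0.5 default is a common convention, not something the gist prescribes; per-label thresholds are usually tuned on held-out validation data. The names model and predict_labels are hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the gist's trained `classifier` (same layer sizes,
# untrained weights); only the decoding steps matter here.
model = nn.Sequential(nn.Linear(2, 64), nn.ReLU(), nn.Linear(64, 3))
model.eval()


def predict_labels(model, inputs, threshold=0.5):
    """Turn raw logits into independent 0/1 decisions per label.

    `threshold` may be a float or a tensor of shape (num_labels,) holding one
    threshold per label; it broadcasts against the (batch, num_labels) probs.
    """
    with torch.no_grad():
        probs = torch.sigmoid(model(inputs))  # logits -> per-label probabilities
    preds = (probs > threshold).long()        # each label switched on independently
    return preds, probs


inputs = torch.tensor([[0.7, 0.0], [0.0, 0.4], [0.5, 0.9]])
preds, probs = predict_labels(model, inputs)

# Per-label thresholds (illustrative values, e.g. tuned on held-out data).
per_label_preds, _ = predict_labels(model, inputs, threshold=torch.tensor([0.5, 0.4, 0.6]))
```

Since sigmoid(x) > 0.5 exactly when x > 0 (mathematically), a 0.5 threshold is the same as checking logits > 0.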


Renthal commented Oct 25, 2018

I think this code does not work as you'd expect.

According to the PyTorch documentation https://pytorch.org/docs/stable/nn.html#multilabelmarginloss, the target vector is NOT a multi-hot encoding:

(v.0.1.12) The criterion only considers the first non zero y[j] targets.
(v.0.4.1) The criterion only considers a contiguous block of non-negative targets that starts at the front.

This can also be verified in the source at https://github.com/pytorch/pytorch/blob/949559552004db317bc5ca53d67f2c62a54383f5/aten/src/THNN/generic/MultiLabelMarginCriterion.c, for example at lines 57 and 65 (also have a look at lines 39 and 40, where the range of the target is checked).

In fact, the correct way to denote a target for classes 0+2 (the example from line 7) is to replace line 16:

labels.append([1, 0, 1])

with

labels.append([0,2,-1])

(as a side note, line 20 should have if category == (1, 1): to match the description at line 9)
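For reference, a sketch of the conversion described in this comment, assuming the -1-padded index format that nn.MultiLabelMarginLoss documents (positive class indices first, remaining slots filled with -1). This format applies to MultiLabelMarginLoss only, not to the MultiLabelSoftMarginLoss used in the gist. The helper name multi_hot_to_margin_target is hypothetical.

```python
import torch
import torch.nn as nn


def multi_hot_to_margin_target(multi_hot: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: convert multi-hot rows such as [1, 0, 1] into the
    index format MultiLabelMarginLoss expects, such as [0, 2, -1]."""
    n, c = multi_hot.shape
    target = torch.full((n, c), -1, dtype=torch.long)  # start with all padding
    for row in range(n):
        idx = torch.nonzero(multi_hot[row], as_tuple=False).flatten()  # positive classes
        target[row, : idx.numel()] = idx  # indices first, -1 padding after
    return target


multi_hot = torch.tensor([[1, 0, 1], [0, 1, 0], [0, 0, 1]])
margin_target = multi_hot_to_margin_target(multi_hot)
# margin_target.tolist() == [[0, 2, -1], [1, -1, -1], [2, -1, -1]]

logits = torch.randn(3, 3)
loss = nn.MultiLabelMarginLoss()(logits, margin_target)  # hinge loss, always >= 0
```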


rchavezj commented Nov 18, 2018

I have trouble computing the accuracy, since prediction in normal single-label classification takes the max (argmax). How do we work around this?
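In the multi-label case there is no single argmax: each label is decided on its own by thresholding its sigmoid probability, and accuracy is measured on those 0/1 decisions. A minimal sketch, assuming the predictions have already been thresholded; subset (exact-match) accuracy and per-label accuracy (one minus the Hamming loss) are two common choices. The helper name multilabel_accuracy is hypothetical.

```python
import torch


def multilabel_accuracy(preds: torch.Tensor, targets: torch.Tensor):
    """preds and targets are (N, C) tensors of 0/1 values.

    Returns (subset accuracy, per-label accuracy).
    """
    correct = preds.bool() == targets.bool()               # (N, C) elementwise matches
    subset_acc = correct.all(dim=1).float().mean().item()  # whole label vector must match
    label_acc = correct.float().mean().item()              # fraction of correct label decisions
    return subset_acc, label_acc


# Thresholded predictions (e.g. torch.sigmoid(logits) > 0.5) vs. multi-hot targets.
preds = torch.tensor([[1, 0, 1], [0, 1, 0], [0, 1, 1]])
targets = torch.tensor([[1, 0, 1], [0, 1, 0], [0, 0, 1]])

subset_acc, label_acc = multilabel_accuracy(preds, targets)
# subset_acc: 2 of 3 rows match exactly, about 0.667
# label_acc: 8 of 9 label decisions match, about 0.889
```

Subset accuracy is strict (one wrong label fails the whole sample), while per-label accuracy can look high when most labels are negative, so reporting both is often useful.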


rchavezj commented Nov 28, 2018

Is 0.092 equivalent to 92% or 9.2% for the first-epoch loss?
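Neither, as far as the formula goes: the printed number is a cross-entropy value, not an accuracy percentage. A sketch of how to read it, based on the definition in PyTorch's MultiLabelSoftMarginLoss documentation: for one sample with C labels, logits x and multi-hot targets y, the loss is the mean per-label binary cross-entropy (σ is the logistic sigmoid), and the gist then averages it over all samples of the epoch with np.mean.

$$\ell(x, y) = -\frac{1}{C}\sum_{i=1}^{C}\Big[y_i \log \sigma(x_i) + (1 - y_i)\log\big(1 - \sigma(x_i)\big)\Big]$$

As a rough intuition only: if every label received probability p for its correct value, each bracketed term would equal log p, so ℓ = -log p, and a loss of 0.092 would correspond to p = e^{-0.092} ≈ 0.912. In practice the per-label probabilities differ, so this is a confidence reading, not an accuracy.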


erobic commented Feb 18, 2019

Thank you @Renthal. I just wasted 2 hours on this and finally read your comment.

The code in this gist is incorrect. As @Renthal said, the leftmost columns for each example should be the ground-truth class indices. The remaining columns should be filled with -1. Of course, each example may belong to a different number of classes.


andreydung commented Mar 11, 2019

@erobic @Renthal note that he is using MultiLabelSoftMarginLoss, not MultiLabelMarginLoss.
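A quick check of that point, as a sketch: with default settings (reduction='mean', no weight), nn.MultiLabelSoftMarginLoss on multi-hot float targets should give the same value as nn.BCEWithLogitsLoss, since both average the per-label binary cross-entropy computed from raw logits. That is why multi-hot targets like the gist's [1, 0, 1] are the expected format for this loss.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(4, 3)  # raw network outputs, no sigmoid applied
multi_hot = torch.tensor([[1.0, 0.0, 1.0],
                          [0.0, 1.0, 0.0],
                          [0.0, 0.0, 1.0],
                          [1.0, 1.0, 1.0]])  # float multi-hot targets, one column per label

soft_margin = nn.MultiLabelSoftMarginLoss()(logits, multi_hot)
bce = nn.BCEWithLogitsLoss()(logits, multi_hot)
# Both average the same per-label binary cross-entropy terms over all 4 x 3 entries,
# so the two values should agree up to floating-point rounding.
```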


wj-Mcat commented Apr 6, 2020

When I change the label format to -1-padded indices, as shown below:

for i in range(10000):
    category = (np.random.choice([0, 1]), np.random.choice([0, 1]))

    if category == (1, 1):
        train.append([np.random.uniform(0.1, 1), np.random.uniform(0.1, 1)])
        # labels.append([1, 0, 1])
        labels.append([0, 2, -1])

    if category == (1, 0):
        train.append([np.random.uniform(0.1, 1), 0])
        # labels.append([0, 1, 0])
        labels.append([1, -1, -1])

    if category == (0, 1):
        train.append([0, np.random.uniform(0.1, 1)])
        # labels.append([0, 0, 1])
        labels.append([2, -1, -1])

    if category == (0, 0):
        train.append([np.random.uniform(0.1, 1), np.random.uniform(0.1, 1)])
        # labels.append([1, 0, 0])
        labels.append([0, -1, -1])

But I get a strange loss value:

[1/5] Loss: -1262.730
[2/5] Loss: -7461.019
[3/5] Loss: -18611.219
[4/5] Loss: -34584.168
[5/5] Loss: -55333.562

Final question: how do I decode the output logits of a multi-class model?
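A plausible explanation, sketched under the assumption that these -1-padded targets were passed unchanged to MultiLabelSoftMarginLoss as in the gist's loop: that loss reads each target entry y_i as a 0/1 value in -[y_i log σ(x_i) + (1 - y_i) log(1 - σ(x_i))]. For y_i = 2 the coefficient (1 - y_i) turns negative, and for y_i = -1 the coefficient y_i does, so the loss has no lower bound and the optimizer can drive it toward -∞ as the logits grow, which matches the diverging negative values above. The index format belongs with nn.MultiLabelMarginLoss; MultiLabelSoftMarginLoss keeps multi-hot targets. A small check:

```python
import torch
import torch.nn as nn

criterion = nn.MultiLabelSoftMarginLoss()

index_target = torch.tensor([[2.0, -1.0, -1.0]])  # -1-padded index format (meant for MultiLabelMarginLoss)
multi_hot = torch.tensor([[0.0, 0.0, 1.0]])       # the same labeling as a multi-hot vector

logits = torch.tensor([[10.0, 0.0, 0.0]])
wrong_loss = criterion(logits, index_target).item()  # negative
ok_loss = criterion(logits, multi_hot).item()        # non-negative

# Raising the first logit further lowers the mis-specified loss without bound.
wrong_loss_further = criterion(torch.tensor([[20.0, 0.0, 0.0]]), index_target).item()
```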


wj-Mcat commented Apr 7, 2020

I created a custom multi-class loss function, but training is too slow.

class MultilabelCrossEntropyLoss(nn.Module):
    def __init__(self):
        super(MultilabelCrossEntropyLoss, self).__init__()

    def forward(self, source: torch.Tensor, target: torch.Tensor) -> torch.Tensor:

        source = source.sigmoid()

        score = -1. * target * source.log() - (1 - target) * torch.log(1-source)
        return score.sum()

I got the result:

[1/500] Loss: 1.067
[2/500] Loss: 0.815
[3/500] Loss: 0.722
[4/500] Loss: 0.664
[5/500] Loss: 0.622
[6/500] Loss: 0.591
[7/500] Loss: 0.566
[8/500] Loss: 0.546
[9/500] Loss: 0.529
[10/500] Loss: 0.515
[11/500] Loss: 0.503
[12/500] Loss: 0.492
[13/500] Loss: 0.483
[14/500] Loss: 0.475
[15/500] Loss: 0.468
[16/500] Loss: 0.461
[17/500] Loss: 0.456
[18/500] Loss: 0.450

Why?
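A hedged sketch of a more robust version of that loss, with observations rather than a confirmed diagnosis. Calling sigmoid() and then log() separately can hit log(0) and produce NaN once a logit saturates, whereas F.binary_cross_entropy_with_logits works on the logits directly. Also, with score.sum() and one sample per step (as in the gist's loop), the value is C times the mean-reduced MultiLabelSoftMarginLoss, so these numbers are not directly comparable to the gist's. Finally, if the multi-hot labels commented out in the previous comment were used, categories (1, 1) and (0, 0) draw inputs from the same distribution yet differ in the third label, which would leave part of the loss irreducible however long training runs. The class name StableMultilabelBCELoss is hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class StableMultilabelBCELoss(nn.Module):
    """Hypothetical replacement for the custom MultilabelCrossEntropyLoss above.

    Takes raw logits and lets binary_cross_entropy_with_logits combine the
    sigmoid and the log in one numerically stable step.
    """

    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # reduction="sum" mirrors score.sum() in the custom loss
        return F.binary_cross_entropy_with_logits(logits, target, reduction="sum")


logits = torch.tensor([[30.0, -30.0, 0.5]])  # two saturated logits, one moderate
target = torch.tensor([[1.0, 0.0, 1.0]])

# The custom formulation: sigmoid first, then log.
probs = logits.sigmoid()
naive = (-1.0 * target * probs.log() - (1 - target) * torch.log(1 - probs)).sum()

stable = StableMultilabelBCELoss()(logits, target)
soft_margin = nn.MultiLabelSoftMarginLoss()(logits, target)

# naive is nan here: sigmoid(30.0) rounds to exactly 1.0 in float32, so the
# second term evaluates 0 * log(0) = 0 * -inf.
# stable is finite and equals 3 * soft_margin (sum over 3 labels vs. mean).
```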
