Generate one hot labels from integer labels in PyTorch
```python
import torch
from torch.autograd import Variable


def make_one_hot(labels, C=2):
    '''
    Converts an integer label torch.autograd.Variable to a one-hot Variable.

    Parameters
    ----------
    labels : torch.autograd.Variable of torch.cuda.LongTensor
        N x 1 x H x W, where N is batch size.
        Each value is an integer representing correct classification.
    C : integer.
        number of classes in labels.

    Returns
    -------
    target : torch.autograd.Variable of torch.cuda.FloatTensor
        N x C x H x W, where C is class number. One-hot encoded.
    '''
    # Allocate a zero-filled N x C x H x W float tensor on the GPU.
    one_hot = torch.cuda.FloatTensor(labels.size(0), C, labels.size(2), labels.size(3)).zero_()
    # Write a 1 along the class dimension (dim 1) at each pixel's label index.
    target = one_hot.scatter_(1, labels.data, 1)

    target = Variable(target)
    return target
```
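For readers on PyTorch 0.4 or later, where Variable was merged into Tensor, a minimal sketch of the same scatter-based idea without Variable or a hard-coded CUDA tensor might look like the following. It keeps the N x 1 x H x W input contract from the docstring and allocates the output on whatever device `labels` lives on; the type hints and the `labels.long()` cast are additions, not part of the original gist.

```python
import torch


def make_one_hot(labels: torch.Tensor, C: int = 2) -> torch.Tensor:
    """One-hot encode N x 1 x H x W integer labels into an N x C x H x W float tensor."""
    n, _, h, w = labels.shape
    # Allocate on the same device as the labels instead of hard-coding CUDA.
    one_hot = torch.zeros(n, C, h, w, dtype=torch.float32, device=labels.device)
    # scatter_ writes 1.0 at class index labels[n, 0, y, x] along dim 1.
    return one_hot.scatter_(1, labels.long(), 1.0)


labels = torch.tensor([[[[0, 1], [2, 1]]]])  # shape 1 x 1 x 2 x 2
print(make_one_hot(labels, C=3).shape)  # torch.Size([1, 3, 2, 2])
```

Allocating with `device=labels.device` lets the same function run on CPU tensors, which also makes it easier to test without a GPU.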

bentaculum commented Jun 6, 2018

The exact example you use in your blog post does not work:

```
Traceback (most recent call last):
  File "test_one_hot.py", line 57, in <module>
    make_one_hot(labels)
  File "test_one_hot.py", line 48, in make_one_hot
    one_hot = torch.cuda.FloatTensor(labels.size(0), C, labels.size(2), labels.size(3)).zero_()
RuntimeError: dimension out of range (expected to be in range of [-2, 1], but got 2)
```

Do you have any idea why? I also tried

```python
labels = torch.LongTensor(1,4,4) % 3
make_one_hot(labels)
```

but got a similar error:

```
Traceback (most recent call last):
  File "test_one_hot.py", line 57, in <module>
    make_one_hot(labels)
  File "test_one_hot.py", line 48, in make_one_hot
    one_hot = torch.cuda.FloatTensor(labels.size(0), C, labels.size(2), labels.size(3)).zero_()
RuntimeError: dimension out of range (expected to be in range of [-3, 2], but got 3)
```


bentaculum commented Jun 6, 2018

I found the issue: your code assumes that labels has already been unsqueezed at dimension 1 (N x 1 x H x W), but your example does not provide a tensor in that form.
In general, I am not sure that is a reasonable assumption to make, so I have added the unsqueeze step to make_one_hot.
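A minimal sketch of the fix described above, assuming labels arrive as either N x H x W or N x 1 x H x W: the function unsqueezes a 3-D input at dimension 1 before scattering and rejects any other shape. bentaculum's actual code was not posted, so the shape check and error message here are illustrative additions.

```python
import torch


def make_one_hot(labels: torch.Tensor, C: int = 2) -> torch.Tensor:
    """Accept N x H x W or N x 1 x H x W labels; return N x C x H x W one-hot floats."""
    if labels.dim() == 3:
        # N x H x W -> N x 1 x H x W so that dim 1 can hold the class index.
        labels = labels.unsqueeze(1)
    if labels.dim() != 4 or labels.size(1) != 1:
        raise ValueError(f"expected N x H x W or N x 1 x H x W labels, got {tuple(labels.shape)}")
    n, _, h, w = labels.shape
    one_hot = torch.zeros(n, C, h, w, device=labels.device)
    return one_hot.scatter_(1, labels.long(), 1.0)


# A 1 x 4 x 4 label map like the one in the comment above. torch.randint is used
# because torch.LongTensor(1, 4, 4) allocates uninitialized memory.
labels = torch.randint(0, 3, (1, 4, 4))
print(make_one_hot(labels, C=3).shape)  # torch.Size([1, 3, 4, 4])
```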


rodrigoberriel commented Jul 9, 2019

It looks like this gist gets a fair number of visits, so I would like to point out that PyTorch has a one_hot function in the nn.functional package: torch.nn.functional.one_hot(tensor, num_classes=-1)
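A short usage sketch of torch.nn.functional.one_hot. It takes an int64 tensor of class indices; with the default num_classes=-1 the number of classes is inferred as one more than the largest value in the input. The result is itself an int64 tensor, so a `.float()` call is usually needed before using it as a loss target.

```python
import torch
import torch.nn.functional as F

labels = torch.tensor([0, 2, 1])
# num_classes defaults to -1, meaning "infer it as labels.max() + 1".
print(F.one_hot(labels))
# tensor([[1, 0, 0],
#         [0, 0, 1],
#         [0, 1, 0]])
print(F.one_hot(labels, num_classes=4).shape)  # torch.Size([3, 4])
```

Passing num_classes explicitly is safer for batches that may not contain the highest class index.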

jacobkimmel (gist author) commented Jul 10, 2019

Thanks @rodrigoberriel, I think the official function was added as part of the big torch==0.4.0 bump that came out a year or so after this gist was published. Folks should use that instead.


depthwise commented Nov 10, 2019

Note that PyTorch's one_hot appends the class dimension as the last dimension, so an N x H x W label tensor becomes N x H x W x C (NHWC) rather than the NCHW layout that PyTorch predictions usually come in. To convert it to NCHW, add .permute(0, 3, 1, 2).
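A sketch of that conversion, assuming integer labels of shape N x H x W: one_hot produces N x H x W x C, permute moves the class axis to dimension 1, and `.float()` converts the int64 result for use with losses. The last lines cross-check the result against the gist's scatter-based approach, adapted here to run on CPU.

```python
import torch
import torch.nn.functional as F

N, C, H, W = 2, 3, 4, 5
labels = torch.randint(0, C, (N, H, W))        # N x H x W integer labels
one_hot = F.one_hot(labels, num_classes=C)      # N x H x W x C, int64
one_hot = one_hot.permute(0, 3, 1, 2).float()   # N x C x H x W, float for losses
print(one_hot.shape)  # torch.Size([2, 3, 4, 5])

# Cross-check against the scatter approach: unsqueeze to N x 1 x H x W first.
ref = torch.zeros(N, C, H, W).scatter_(1, labels.unsqueeze(1), 1.0)
print(torch.equal(one_hot, ref))  # True
```

permute returns a view with rearranged strides rather than copying data, so add `.contiguous()` if downstream code requires contiguous memory.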
