Last active
May 15, 2021 04:02
Generate one hot labels from integer labels in PyTorch
import torch
from torch.autograd import Variable


def make_one_hot(labels, C=2):
    '''
    Converts an integer label torch.autograd.Variable to a one-hot Variable.

    Parameters
    ----------
    labels : torch.autograd.Variable of torch.cuda.LongTensor
        N x 1 x H x W, where N is batch size.
        Each value is an integer representing correct classification.
    C : integer.
        number of classes in labels.

    Returns
    -------
    target : torch.autograd.Variable of torch.cuda.FloatTensor
        N x C x H x W, where C is class number. One-hot encoded.
    '''
    # Allocate a zeroed N x C x H x W tensor on the GPU.
    one_hot = torch.cuda.FloatTensor(labels.size(0), C, labels.size(2), labels.size(3)).zero_()
    # Write a 1 along the class dimension at each pixel's label index.
    target = one_hot.scatter_(1, labels.data, 1)
    target = Variable(target)
    return target
It looks like this gist has some visits, so I would like to let you know that PyTorch has one_hot in the nn.functional package: torch.nn.functional.one_hot(tensor, num_classes=0)
Thanks @rodrigoberriel, I think the official function was added as part of the big torch==0.4.0
bump that came out a year or so after this gist was published. Folks should use that instead.
Note that PyTorch's one_hot expands the last dimension, so the resulting tensor is NHWC rather than the PyTorch-standard NCHW that your prediction is likely to come in. To turn it into NCHW, one would need to add .permute(0,3,1,2).
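A minimal sketch of the point above, assuming the labels are an N x H x W integer tensor (no channel dimension). The batch size, image size, and class count are made-up illustration values, not anything from the gist:

```python
import torch
import torch.nn.functional as F

# Hypothetical example data: 2 label maps of 4 x 5 pixels with 3 classes.
num_classes = 3
labels = torch.randint(0, num_classes, (2, 4, 5))  # N x H x W, dtype int64

# F.one_hot appends the class dimension last: N x H x W x C.
one_hot_nhwc = F.one_hot(labels, num_classes=num_classes)

# Move the class dimension to position 1 to get N x C x H x W.
# .float() is added because F.one_hot returns an integer tensor.
one_hot_nchw = one_hot_nhwc.permute(0, 3, 1, 2).float()

print(one_hot_nhwc.shape, one_hot_nchw.shape)
```

permute returns a non-contiguous view, so a .contiguous() call may be needed before operations that require contiguous memory.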
I found the issue: Your code assumes that labels is already unsqueezed at dimension 1. However, your example does not provide a tensor in that form. Generally speaking, I am not sure if the above is a reasonable assumption to make, and have included the unsqueezing in make_one_hot.
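The commenter's modified code is not shown in this thread, so the following is only a sketch of what folding the unsqueeze into the function might look like. The name make_one_hot_unsqueezed is hypothetical, and the sketch allocates on labels.device instead of assuming CUDA:

```python
import torch


def make_one_hot_unsqueezed(labels, C=2):
    """Hypothetical variant: accepts N x H x W or N x 1 x H x W integer labels."""
    if labels.dim() == 3:
        labels = labels.unsqueeze(1)  # N x H x W -> N x 1 x H x W
    # Zeroed N x C x H x W tensor on the same device as the labels.
    one_hot = torch.zeros(
        labels.size(0), C, labels.size(2), labels.size(3), device=labels.device
    )
    # Write a 1 along the class dimension at each pixel's label index.
    return one_hot.scatter_(1, labels.long(), 1)


# Usage with a made-up 1 x 2 x 2 label map that has no channel dimension.
example_labels = torch.tensor([[[0, 1], [1, 0]]])
example_target = make_one_hot_unsqueezed(example_labels, C=2)
print(example_target.shape)
```

Passing labels that already have the singleton channel dimension skips the unsqueeze, so both input layouts give the same N x C x H x W result.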