Generate one hot labels from integer labels in PyTorch
import torch


def make_one_hot(labels, C=2):
    '''
    Converts an integer label tensor to a one-hot tensor.

    Parameters
    ----------
    labels : torch.LongTensor (CPU or CUDA)
        N x 1 x H x W, where N is batch size.
        Each value is an integer class index in [0, C).
    C : int
        Number of classes in labels.

    Returns
    -------
    target : torch.FloatTensor on the same device as labels
        N x C x H x W, where C is the number of classes. One-hot encoded.
    '''
    # Zero tensor with a class axis of size C, allocated on the labels' device.
    one_hot = torch.zeros(labels.size(0), C, labels.size(2), labels.size(3),
                          dtype=torch.float, device=labels.device)
    # For every pixel, write 1 into the channel given by its label (dim 1).
    target = one_hot.scatter_(1, labels, 1)
    return target
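A quick way to see what scatter_ does here is a tiny CPU example. The sketch below uses a hypothetical helper, one_hot_scatter, that follows the same steps as make_one_hot (zeros of shape N x C x H x W, then scatter_ along dim 1) so that it runs on its own without CUDA; the label values are made up for illustration.

```python
import torch


def one_hot_scatter(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical helper mirroring make_one_hot: (N, 1, H, W) int64 -> (N, C, H, W) float."""
    n, _, h, w = labels.shape
    out = torch.zeros(n, num_classes, h, w, dtype=torch.float, device=labels.device)
    # scatter_(dim=1, index, value) sets out[i][labels[i][0][y][x]][y][x] = 1.0
    return out.scatter_(1, labels, 1.0)


# A 1 x 1 x 2 x 2 label map with classes in {0, 1, 2}.
labels = torch.tensor([[[[0, 2],
                         [1, 0]]]])
target = one_hot_scatter(labels, num_classes=3)

print(target.shape)  # one channel per class: N x C x H x W
print(target[0, 2])  # channel 2 is 1.0 only where the label was 2

# Every pixel has exactly one active channel, and argmax over dim 1 recovers the labels.
assert torch.all(target.sum(dim=1) == 1)
assert torch.equal(target.argmax(dim=1, keepdim=True), labels)
```

Because the index tensor keeps its singleton channel axis, scatter_ can use it directly along dim 1 with no reshaping.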

depthwise commented Nov 10, 2019

Note that PyTorch's one_hot (torch.nn.functional.one_hot) appends the class dimension as the last axis, so for an N x H x W label tensor the result is N x H x W x C (NHWC) rather than the NCHW layout that PyTorch uses by default and that your predictions most likely come in. To turn it into NCHW, add .permute(0, 3, 1, 2).
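A minimal sketch of the conversion described in this comment, assuming the labels arrive as an N x H x W int64 tensor (without the singleton channel axis that make_one_hot expects); the shapes and class count are arbitrary. It also casts to float, since F.one_hot returns an int64 tensor, and checks the result against the scatter-based construction from the gist.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical segmentation labels: N=2, H=3, W=4, classes in [0, 5).
num_classes = 5
labels = torch.randint(0, num_classes, (2, 3, 4))  # N x H x W, int64

# F.one_hot appends the class axis last: N x H x W x C (NHWC), dtype int64.
nhwc = F.one_hot(labels, num_classes=num_classes)

# Move the class axis to position 1 to get NCHW, and cast to float so it can be
# combined with float predictions (e.g. in a loss).
nchw = nhwc.permute(0, 3, 1, 2).float()

# Cross-check against the scatter-based approach (index needs the N x 1 x H x W shape).
reference = torch.zeros(2, num_classes, 3, 4).scatter_(1, labels.unsqueeze(1), 1.0)
assert nchw.shape == (2, num_classes, 3, 4)
assert torch.equal(nchw, reference)
```

permute only changes strides, so the result is non-contiguous; calling .contiguous() afterwards may be worthwhile if a later op requires contiguous memory.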
