Skip to content

Instantly share code, notes, and snippets.

@visualDust
Last active November 19, 2023 12:24
Show Gist options
  • Save visualDust/36a22ea4ef72cdb1c42aed79dec7615b to your computer and use it in GitHub Desktop.
Save visualDust/36a22ea4ef72cdb1c42aed79dec7615b to your computer and use it in GitHub Desktop.
One-hot encoding that supports an ignored label
def one_hot_encode(
    mask: torch.Tensor,
    num_classes: int,
    ignored_label: Union[str, int] = "negative",
):
    """Convert an indexed label mask to a one-hot encoded tensor (by @visualDust).

    Ignored pixels get all-zero class vectors. Internally they are routed to an
    extra scratch channel at index ``num_classes``, which is dropped before
    returning.

    Args:
        mask (torch.Tensor): indexed label image with one integer class id per
            pixel, shaped (H, W), (1, H, W) or (B, 1, H, W). Any integer dtype
            is accepted; values are converted to int64 for indexing. The input
            tensor is never modified.
        num_classes (int): number of classes; valid labels are 0..num_classes-1.
        ignored_label (Union[str, int], optional): an int label value to
            ignore, or the pattern "negative" to ignore every label < 0. Any
            other string disables ignoring. Defaults to "negative".

    Returns:
        torch.Tensor: one-hot tensor in the default floating dtype, on the same
        device as ``mask``. Shape is (num_classes, H, W) for 2-D/3-D input and
        (B, num_classes, H, W) for 4-D input.

    Raises:
        RuntimeError: if a non-ignored label is negative or >= num_classes.
    """
    original_ndim = mask.dim()

    # Add leading singleton dims (views only): H W -> 1 H W -> 1 1 H W.
    batched = mask
    while batched.dim() < 4:
        batched = batched.unsqueeze(0)

    # Locate ignored pixels on the original values, before any dtype conversion.
    if isinstance(ignored_label, int) and not isinstance(ignored_label, bool):
        ignored = batched == ignored_label
    elif ignored_label == "negative":
        ignored = batched < 0
    else:
        ignored = None

    # scatter_ requires an int64 index tensor.
    index = batched.long()
    if ignored is None:
        labels = index
    else:
        # Validate only the labels that will actually be encoded.
        labels = index[~ignored]
        # masked_fill is out-of-place, so the caller's mask is left untouched.
        index = index.masked_fill(ignored, num_classes)

    # Validate BEFORE remapping can hide a genuine out-of-range label.
    if labels.numel() > 0 and (labels.min() < 0 or labels.max() >= num_classes):
        raise RuntimeError(
            "class values must be non-negative and smaller than num_classes."
        )

    B, _, H, W = index.shape
    # One extra channel receives the ignored pixels; allocate on the mask's device.
    one_hot = torch.zeros(B, num_classes + 1, H, W, device=index.device)
    one_hot.scatter_(1, index, 1)  # mark 1 on channel (dim=1) given by the label
    one_hot = one_hot[:, :num_classes]  # drop the ignored-label channel

    # Remove only the batch dim that was added; never squeeze the class dim.
    if original_ndim < 4:
        one_hot = one_hot[0]
    return one_hot
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment