Masked Softmax PyTorch
A small helper that applies softmax while ignoring masked positions: entries where the boolean mask is True are filled with the dtype's minimum value before the softmax, so they end up with (effectively) zero probability.
import torch
from torch.nn.functional import softmax

def masked_softmax(input: torch.Tensor, bool_mask: torch.Tensor, dim: int = -1, dtype: torch.dtype = None) -> torch.Tensor:
    # Positions where bool_mask is True are filled with the most negative value
    # representable in input's dtype, so they contribute ~0 after the softmax.
    min_type_value = torch.finfo(input.dtype).min
    masked_value = input.masked_fill(bool_mask, min_type_value)
    return softmax(masked_value, dim = dim, dtype = dtype)
## Example
# a = torch.rand(1, 3)
# print(a) ## tensor([[0.1218, 0.7097, 0.4001]])
# b = masked_softmax(a, a < 0.5)
# print(b) ## tensor([[0., 1., 0.]])
# c = masked_softmax(a, torch.tensor([True, False, False]))
# print(c) ## tensor([[0.0000, 0.5768, 0.4232]])
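## Attention-style usage (an illustrative sketch, not part of the original gist;
## the tensor names, shapes, and values below are made up for the example)
# scores = torch.rand(2, 4, 4)                           ## (batch, query, key) attention scores
# pad_mask = torch.tensor([[False, False, False, True],
#                          [False, False, True,  True]]) ## True marks padded key positions
# attn = masked_softmax(scores, pad_mask.unsqueeze(1))   ## mask broadcasts over the query dim
# print(attn.sum(dim = -1)) ## each row still sums to 1; padded keys receive ~0 weight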