@insaneyilin · Created November 30, 2022 05:05
log_softmax and logsumexp with mask
import torch


def masked_log_softmax(input, mask, dim=1):
    """log_softmax over `dim`, ignoring entries where `mask` is 0."""
    masked_input = input * mask.float()
    # Subtract the per-row max for numerical stability.
    max_input = torch.max(masked_input, dim=dim, keepdim=True)[0]
    exps = torch.exp(masked_input - max_input)
    masked_exps = exps * mask.float()
    masked_sums = masked_exps.sum(dim, keepdim=True)
    # Guard against rows where every entry is masked out.
    zeros = (masked_sums == 0)
    masked_sums += zeros.float()
    masked_exps += 1e-6  # avoid zero input of log
    return torch.log(masked_exps / masked_sums)
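A quick sanity check (not part of the original gist; the tensors are made up for illustration): with an all-ones mask the result should match F.log_softmax up to the 1e-6 offset added above, and a masked-out column should get a large negative (but finite) log-probability.

import torch.nn.functional as F

x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])

# All-ones mask: should agree with the unmasked log_softmax.
full_mask = torch.ones_like(x)
out = masked_log_softmax(x, full_mask, dim=1)
print(torch.allclose(out, F.log_softmax(x, dim=1), atol=1e-4))  # True

# Mask out the third column: valid positions should match log_softmax
# computed over only the unmasked values, while the masked position
# gets log(1e-6 / sum), a large negative number rather than -inf.
part_mask = torch.tensor([[1.0, 1.0, 0.0, 1.0]])
out = masked_log_softmax(x, part_mask, dim=1)
ref = F.log_softmax(torch.tensor([1.0, 2.0, 4.0]), dim=0)
print(torch.allclose(out[0, [0, 1, 3]], ref, atol=1e-4))  # True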
def masked_log_sum_exp(input, keepdim=False, mask=None):
    """Numerically stable logsumexp on the last dim of `input`.

    Reference: https://github.com/pytorch/pytorch/issues/2591

    Args:
        input: A tensor with any shape.
        keepdim: A boolean.
        mask: A float mask with the same shape as `input`.
            Valid entries are set to one.

    Returns:
        Equivalent of log(sum(exp(input), dim=-1, keepdim=keepdim)).
    """
    if mask is not None:
        # Flip the mask so invalid entries are 1, then push them far
        # below the valid values before taking the max.
        mask = 1. - mask
        max_offset = -1e7 * mask
    else:
        max_offset = 0.
    s, _ = torch.max(input + max_offset, dim=-1, keepdim=True)
    input_offset = input - s
    if mask is not None:
        # Invalid entries contribute exp(-inf) = 0 to the sum.
        input_offset.masked_fill_((mask > 1e-6), -float('inf'))
    output = s + input_offset.exp().sum(dim=-1, keepdim=True).log()
    if not keepdim:
        output = output.squeeze(-1)
    return output
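And a similar check for the masked logsumexp (again, my own example tensors, not from the gist): each row should equal torch.logsumexp over just the unmasked entries.

x = torch.tensor([[1.0, 2.0, 3.0],
                  [0.5, -1.0, 2.0]])
mask = torch.tensor([[1.0, 1.0, 0.0],
                     [1.0, 1.0, 1.0]])
out = masked_log_sum_exp(x, mask=mask)  # shape (2,)

# Row 0 sums over [1.0, 2.0]; row 1 sums over all three entries.
ref = torch.stack([torch.logsumexp(x[0, :2], dim=0),
                   torch.logsumexp(x[1], dim=0)])
print(torch.allclose(out, ref))  # True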