@pcyin
Created September 7, 2018 14:29
PyTorch masked `log_sum_exp`
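The function below implements the standard max-shift trick, logsumexp(x) = max(x) + log(sum_i exp(x_i - max(x))), excluding masked entries from both the max (via a large negative offset) and the sum (via exp(-inf) = 0).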
import torch


def log_sum_exp(inputs, keepdim=False, mask=None):
    """Numerically stable logsumexp over the last dim of `inputs`.

    Reference: https://github.com/pytorch/pytorch/issues/2591

    Args:
        inputs: A tensor of any shape.
        keepdim: A boolean.
        mask: An optional float mask with the same shape as `inputs`.
            **ATTENTION**: valid entries are marked with ONE, masked-out
            (invalid) entries with ZERO.

    Returns:
        Equivalent of log(sum(exp(inputs), dim=-1, keepdim=keepdim)),
        with masked entries excluded from the sum.
    """
    if mask is not None:
        # Flip the mask so invalid entries become 1, then push those
        # entries far down so they cannot win the max below.
        mask = 1. - mask
        max_offset = -1e7 * mask
    else:
        max_offset = 0.
    # Max-shift trick: subtract the per-row (masked) max before exponentiating.
    s, _ = torch.max(inputs + max_offset, dim=-1, keepdim=True)
    inputs_offset = inputs - s
    if mask is not None:
        # exp(-inf) == 0, so invalid entries contribute nothing to the sum.
        inputs_offset = inputs_offset.masked_fill(mask.bool(), -float('inf'))
    outputs = s + inputs_offset.exp().sum(dim=-1, keepdim=True).log()
    if not keepdim:
        outputs = outputs.squeeze(-1)
    return outputs
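
As a quick sanity check, here is a minimal usage sketch (the tensor values and shapes are arbitrary examples, not part of the original gist); the masked result should agree with `torch.logsumexp` applied to only the valid entries:

import torch

# Two rows of four scores; mask == 0 marks padding entries to ignore.
scores = torch.tensor([[0.5, 1.2, -0.3, 9.9],
                       [2.0, 0.1, 9.9, 9.9]])
mask = torch.tensor([[1., 1., 1., 0.],
                     [1., 1., 0., 0.]])

out = log_sum_exp(scores, mask=mask)
print(out)
# Should match logsumexp over just the valid entries of each row:
print(torch.logsumexp(scores[0, :3], dim=-1))  # row 0
print(torch.logsumexp(scores[1, :2], dim=-1))  # row 1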