
@cwhy
Last active July 23, 2020 09:32
PyTorch tensor plays
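```python
import torch


def cum_softmax(t: torch.Tensor) -> torch.Tensor:
    # t shape: ..., sm_d, where sm_d is the dim to reduce.
    # Output shape: ..., sm_d, csm_d; out[..., i, j] is the softmax weight of
    # element i within the prefix t[..., :j+1], and 0 when i > j.
    n = t.size(-1)
    # Running max of each prefix (shape ..., 1, csm_d) keeps exp() from overflowing
    tmax = t.cummax(dim=-1).values.unsqueeze(-2)
    idx = torch.arange(n, device=t.device)
    keep = idx.unsqueeze(-1) <= idx  # keep[i, j]: element i belongs to prefix j
    logits = (t.unsqueeze(-1) - tmax).masked_fill(~keep, float("-inf"))
    numerator = logits.exp()
    denominator = numerator.sum(dim=-2, keepdim=True)
    return numerator / denominator
```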
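A cumulative softmax can be cross-checked against the plain definition: column j should equal `torch.softmax` applied to the prefix `t[..., :j+1]`, with zeros below it. The sketch below is an illustration under assumptions, not the gist author's code. `cum_softmax_reference` is a hypothetical loop that does exactly that, and `cum_softmax_lcse` is a hypothetical alternative that swaps the running max for `torch.logcumsumexp`, assuming a PyTorch version that provides it. The log-normalizer of prefix j is `logcumsumexp(t)[j]`, so each weight is `exp(t_i - logcumsumexp(t)_j)` for i <= j.

```python
import torch


def cum_softmax_reference(t: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: one ordinary softmax per prefix, written into column j.
    n = t.size(-1)
    out = t.new_zeros(t.shape + (n,))  # shape: ..., sm_d, csm_d
    for j in range(n):
        out[..., : j + 1, j] = torch.softmax(t[..., : j + 1], dim=-1)
    return out


def cum_softmax_lcse(t: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: logcumsumexp gives each prefix's log-normalizer directly,
    # so no separate running max is needed for numerical stability.
    n = t.size(-1)
    log_norm = torch.logcumsumexp(t, dim=-1).unsqueeze(-2)  # shape: ..., 1, csm_d
    idx = torch.arange(n, device=t.device)
    keep = idx.unsqueeze(-1) <= idx  # keep[i, j] is True when i <= j
    # Masked entries become -inf, so exp() turns them into exact zeros
    logits = (t.unsqueeze(-1) - log_norm).masked_fill(~keep, float("-inf"))
    return logits.exp()


t = torch.randn(2, 5)
print(torch.allclose(cum_softmax_lcse(t), cum_softmax_reference(t), atol=1e-6))
```

Both variants materialize the full sm_d × csm_d weight matrix. The loop is only meant as a readable specification for testing, while the masked, vectorized form computes every prefix in a single pass.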