Skip to content

Instantly share code, notes, and snippets.

@DeNeutoy
Forked from sjmielke/logsumexp_experiments.py
Last active April 30, 2019 20:01
Show Gist options
Save DeNeutoy/a975809e9a5179c43fd3771590d55036 to your computer and use it in GitHub Desktop.
from mpmath import mp
import torch

# Run mpmath at 100 significant decimal digits so its results serve as a
# reference far more precise than float64.
mp.dps = 100

# Take some list of values that shrink to be really small in log space:
# the unnormalized log-weights 0, -1, ..., -19, normalized in two ways.

# (a) float64 normalization via PyTorch's log_softmax.
torch_lps = torch.log_softmax(-torch.arange(20.0, dtype=torch.float64), dim=0)

# (b) The same normalization with Z = sum_k exp(-k) and log Z evaluated in
# 100-digit arithmetic; only the final log-probabilities are rounded to float64.
_raw_lps = [mp.mpf(x) for x in (-torch.arange(20.0, dtype=torch.float64)).tolist()]
Z = mp.fsum(mp.exp(x) for x in _raw_lps)
log_Z = mp.log(Z)  # loop-invariant: evaluate once instead of once per element
mpmath_lps = torch.tensor([float(x - log_Z) for x in _raw_lps], dtype=torch.float64)
# For each set of normalized log-probabilities computed above, accumulate the
# probabilities exp(lps[0]), exp(lps[1]), ... in log space in several different
# ways and print every prefix sum next to mpmath's 100-digit reference.
# Since the probabilities sum to 1, every running log-sum should reach 0 exactly
# when the last term is added.
# NOTE(review): relies on the module-level `torch`, `mp`, `torch_lps` and `mpmath_lps`.
_experiments = [
    (
        torch_lps,
        "PyTorch's logsumexp/logsoftmax",
        "PyTorch LSE does end up at 0 as it should... but too early!\nmpmath's more exact calculations show that we overshoot!",
    ),
    (
        mpmath_lps,
        "mpmath's 100-digit wide computation",
        "PyTorch LSE again fails to include the smallest, but now thinks that log Z < 0!\nmpmath shows that, no, that's not true, we still overshoot, but less.",
    ),
]
# Column order matches the values printed in each row below.
_COLUMNS = ("sum_naive", "sum_LSE_all", "sum_LSE_bin", "sum_log1p", "sum_rem", "mpmath!")

for lps, name, conclusion in _experiments:
    print("\n\nWhen normalizing with", name, "\n")
    # log(0) = -inf is the identity element for log-space sums.
    # dtype must be passed by keyword: torch.tensor() accepts only one positional argument.
    sum_naive = torch.tensor(0.0, dtype=torch.float64).log()
    sum_LSE = torch.tensor(0.0, dtype=torch.float64).log()
    # log of the probability mass not yet added (starts at log 1 = 0).
    log_remainder = torch.tensor(1.0, dtype=torch.float64).log()
    mp_remainder = mp.mpf(1.0)
    print(f"{'add this ->':>18}" + "".join(f"  {col:>13}" for col in _COLUMNS))
    print("-" * (18 + 15 * len(_COLUMNS)))
    for i, lp in enumerate(lps):
        # Version 0: naive log(exp(a) + exp(b)).
        sum_naive = torch.log(sum_naive.exp() + lp.exp())
        # Version 1: reuse the running sum; binary LSE as a + log1p(exp(b - a)).
        if i == 0:
            sum_log1p = lp
        else:
            # lps is expected to be decreasing, so the running sum dominates the new term.
            assert sum_log1p > lp
            sum_log1p = sum_log1p + (lp - sum_log1p).exp().log1p()
        # Version 2: torch's logsumexp as a binary LSE on (running sum, new term).
        sum_LSE = torch.logsumexp(torch.stack([sum_LSE, lp]), dim=0)
        # Version 3: recompute logsumexp over the whole prefix from scratch.
        sum_scratch = torch.logsumexp(lps[: i + 1], dim=0)
        # Version 4: track the remainder R = 1 - sum incrementally, using
        # log(R - p) = log p + log(expm1(log R - log p)),
        # (equivalently log R + log1p(-exp(log p - log R))), then recover
        # log(sum) = log(1 - R) = log(-expm1(log R)).
        # Yields NaN if rounding ever makes R smaller than the next term.
        log_remainder = lp + (log_remainder - lp).expm1().log()
        sum_rem = torch.log(-log_remainder.expm1())
        # Version 5: arbitrary-precision remainder; log(sum) = log1p(-R).
        mp_remainder = mp_remainder - mp.exp(mp.mpf(lp.item()))
        # Compare!
        values = (
            sum_naive.item(),
            sum_scratch.item(),
            sum_LSE.item(),
            sum_log1p.item(),
            sum_rem.item(),
            float(mp.log1p(-mp_remainder)),
        )
        row = f"... + {lp.exp().item():12.10f}" + "".join(f"  {v:13.10f}" for v in values)
        # Dim every zero (ANSI faint) so the significant digits stand out.
        print(row.replace("0", "\033[2m0\033[0m"))
    print("\n" + conclusion)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.