@vene
Created December 29, 2019 14:27
# compare FY entmax losses with (log)-likelihood objectives
# author: vlad niculae
import torch
import matplotlib.pyplot as plt

from entmax import entmax_bisect, entmax_bisect_loss


def main(alpha=1.5):
    plt.figure()
    ts = torch.linspace(-4, 4, 50)

    # since the implementation is for multiclass,
    # we represent the score as z=[0, t] with y_true=[0, 1]
    Z = torch.stack((torch.zeros_like(ts), ts)).t()

    # (however, y_true is stored as indices)
    y_true = torch.ones_like(ts, dtype=torch.long)

    P = entmax_bisect(Z, alpha=alpha)

    plt.title(f"binary loss for score=t if true_y=1 // alpha={alpha}")
    plt.xlabel("t")
    plt.plot(ts, 1 - P[:, 1], label="1-p")
    plt.plot(ts, -torch.log(P[:, 1]), label="-logp")

    # the Fenchel-Young loss induced by alpha-entmax
    loss = entmax_bisect_loss(Z, y_true, alpha=alpha)
    plt.plot(ts, loss, label="FY loss")

    plt.legend()
    plt.savefig(f"entmax-{alpha}.png")


if __name__ == '__main__':
    main(alpha=1.5)
    main(alpha=2)
    main(alpha=1.05)
vene commented Dec 29, 2019
As alpha → 1, the Tsallis entmax FY loss approaches the negative log-likelihood. But for any alpha > 1, entmax can assign exactly zero probability to the true label, so the negative log-likelihood becomes infinite beyond some score threshold, while the FY loss stays finite and well-behaved everywhere.
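For context, a sketch of why the FY loss stays finite (this is just the standard Fenchel-Young construction, with Omega the negative Tsallis alpha-entropy restricted to the simplex; the notation here is mine, not from the code above):

L_\alpha(z, y) = \Omega_\alpha^*(z) + \Omega_\alpha(e_y) - \langle z, e_y \rangle
             = \langle z, p^\star \rangle + H_\alpha^T(p^\star) - z_y,
\qquad p^\star = \operatorname{entmax}_\alpha(z),

where \Omega_\alpha = -H_\alpha^T and H_\alpha^T(p) = \frac{1}{\alpha(\alpha-1)} \bigl( 1 - \sum_j p_j^\alpha \bigr). Since H_\alpha^T of a one-hot vector is 0 and every term above is finite for finite z, the loss never diverges, no matter how small p^\star_y gets.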

[plots: entmax-2.png, entmax-1.5.png, entmax-1.05.png — the three loss curves produced by the script]
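A quick numerical check of the zero-probability claim, for alpha = 2 (where entmax reduces to sparsemax and the binary probability has the closed form clamp((1 + t) / 2, 0, 1)); the specific test points are mine:

import torch
from entmax import entmax_bisect, entmax_bisect_loss

ts = torch.tensor([-3.0, -2.0, 0.0, 2.0])
Z = torch.stack((torch.zeros_like(ts), ts)).t()
y_true = torch.ones_like(ts, dtype=torch.long)

# for alpha=2, entmax is sparsemax: p = clamp((1 + t) / 2, 0, 1) in this setup
p = entmax_bisect(Z, alpha=2)[:, 1]
print(torch.allclose(p, ((1 + ts) / 2).clamp(0, 1), atol=1e-4))  # True

print(-torch.log(p))                            # inf for t <= -1: the nll blows up
print(entmax_bisect_loss(Z, y_true, alpha=2))   # finite at every test point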
