Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Continoulli with Logits
import torch
import torch.nn as nn
class ContinoulliWithLogitsLoss(nn.BCEWithLogitsLoss):
    """
    Numerically stable negative log-likelihood of the continuous Bernoulli
    distribution parameterised by logits, i.e. the objective defined in [1].

    For a logit ``x`` (``lambda = sigmoid(x)``) and a target ``y`` in [0, 1],
    the continuous Bernoulli density is
    ``C(lambda) * lambda**y * (1 - lambda)**(1 - y)``, so the per-element loss is

        BCEWithLogits(x, y) - log C(lambda),   C(lambda) = x / tanh(x / 2)

    with ``C = 2`` at ``x = 0``. The log-normaliser enters with a *minus*
    sign; this is what makes ``exp(-loss)`` integrate to one over ``y``.

    Constructor arguments are those of ``nn.BCEWithLogitsLoss``:

    * ``weight`` rescales the full per-element loss (BCE term and
      log-normaliser alike).
    * ``pos_weight`` is forwarded to the BCE term only. When it is set, the
      result is no longer a normalised continuous Bernoulli likelihood.
    * ``reduction`` is one of ``'mean'``, ``'sum'`` or ``'none'``.

    [1] https://arxiv.org/abs/1907.06845
    """

    def forward(self, input, target):
        """Return the continuous Bernoulli NLL of ``target`` given logits ``input``.

        Args:
            input: logits, any shape.
            target: values in [0, 1], same shape as ``input``.

        Returns:
            A scalar for reduction 'mean'/'sum'; otherwise the elementwise loss.

        Raises:
            ValueError: if ``self.reduction`` is not 'mean', 'sum' or 'none'.
        """
        # Compute the BCE elementwise so the log-normaliser can be combined
        # per element before any reduction (and so 'none' works).
        bce = torch.nn.functional.binary_cross_entropy_with_logits(
            input, target, pos_weight=self.pos_weight, reduction='none')
        # NLL = -log p(y) = BCE - log C(lambda).
        nll = bce - torch.log(self.Z(input))
        if self.weight is not None:
            nll = nll * self.weight
        if self.reduction == 'mean':
            return nll.mean()
        if self.reduction == 'sum':
            return nll.sum()
        if self.reduction == 'none':
            return nll
        raise ValueError(
            f"{self.reduction!r} is not a valid value for reduction")

    @staticmethod
    def Z(x):
        """Normalising constant ``C(sigmoid(x)) = x / tanh(x / 2)`` from [1].

        ``x / tanh(x / 2)`` has a removable singularity at ``x = 0``, where
        its limit (and global minimum) is 2. Wherever ``tanh`` evaluates to
        exactly zero, the denominator is replaced by 1 to avoid a 0/0 NaN.
        The result is then clamped to at least 2, so those points, and values
        that round just below 2 for tiny ``|x|``, get the correct value.
        """
        t = (0.5 * x).tanh()
        # The mask is detached: it only patches the denominator, it is not a
        # differentiable term.
        return x.div(t + t.detach().eq(0).to(t)).clamp_min(2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment