Skip to content

Instantly share code, notes, and snippets.

@nasimrahaman
Created July 24, 2019 20:28
Show Gist options
  • Save nasimrahaman/26704172cc081d3f74567e65ece73c5b to your computer and use it in GitHub Desktop.
Save nasimrahaman/26704172cc081d3f74567e65ece73c5b to your computer and use it in GitHub Desktop.
Continoulli with Logits
import math

import torch
import torch.nn as nn
class ContinoulliWithLogitsLoss(nn.BCEWithLogitsLoss):
    """
    Negative log-likelihood of the continuous Bernoulli distribution [1],
    parameterized by logits.

    For a logit ``x`` (so ``lambda = sigmoid(x)``) and a target ``t`` in
    [0, 1], the continuous Bernoulli density is

        p(t | lambda) = C * lambda**t * (1 - lambda)**(1 - t),

    whose normalizing constant, written in terms of the logit, is
    ``C(x) = x / tanh(x / 2)`` (with limit ``C(0) = 2``). The loss is

        -log p(t | lambda) = BCEWithLogits(x, t) - log C(x).

    Constructor arguments (``weight``, ``reduction``, ``pos_weight``, ...) are
    those of :class:`torch.nn.BCEWithLogitsLoss`. ``weight`` scales each
    element's full loss (BCE term and log-normalizer alike); ``pos_weight``
    is forwarded to the BCE term only, exactly as in the parent class.

    [1] https://arxiv.org/abs/1907.06845
    """

    # Below this |logit| the log-normalizer is evaluated from its Taylor
    # series: x / tanh(x / 2) is a 0/0 form there and its gradient loses
    # precision to cancellation.
    _TAYLOR_THRESHOLD = 1e-2

    def forward(self, input, target):
        """
        Compute the continuous Bernoulli negative log-likelihood.

        Args:
            input: logits.
            target: values in [0, 1]; must have the same size as ``input``
                (required by ``binary_cross_entropy_with_logits``).

        Returns:
            The elementwise loss when ``reduction == 'none'``, otherwise its
            mean (``'mean'``) or sum (``'sum'``).

        Raises:
            ValueError: if ``self.reduction`` is not 'none', 'mean' or 'sum'.
        """
        # Elementwise (unreduced) BCE so the normalizer can be combined per
        # element before any reduction; weight/pos_weight behave as in the
        # parent class.
        bce = nn.functional.binary_cross_entropy_with_logits(
            input,
            target,
            weight=self.weight,
            pos_weight=self.pos_weight,
            reduction='none',
        )
        log_norm = self.log_Z(input)
        if self.weight is not None:
            # Keep the weighted loss equal to weight * (-log p).
            log_norm = log_norm * self.weight
        # log C enters the log-likelihood with a plus sign, hence it is
        # subtracted in the negative log-likelihood.
        loss = bce - log_norm

        if self.reduction == 'none':
            return loss
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        raise ValueError(
            "{} is not a valid value for reduction".format(self.reduction)
        )

    @staticmethod
    def Z(x):
        """
        Normalizing constant C(x) = x / tanh(x / 2) as a function of the logit.

        Where tanh(x / 2) is exactly zero the denominator is replaced by 1, so
        the ratio becomes 0 instead of NaN; ``clamp_min(2)`` then maps it to
        the limit value 2 (which is also the minimum of C).
        """
        t = (0.5 * x).tanh()
        return x.div(t + t.detach().eq(0).to(t)).clamp_min(2)

    @classmethod
    def log_Z(cls, x):
        """
        Numerically stable ``log C(x) = log(x / tanh(x / 2))``.

        Uses the series ``log 2 + x**2 / 12 - 7 * x**4 / 1440`` for
        ``|x| < _TAYLOR_THRESHOLD`` and the closed form elsewhere. Both the
        value and the gradient are finite at ``x == 0``.
        """
        small = x.abs() < cls._TAYLOR_THRESHOLD
        # Give the closed-form branch a harmless input (1) where its result is
        # discarded, so torch.where cannot back-propagate NaN from 0 / 0.
        x_safe = torch.where(small, torch.ones_like(x), x)
        exact = torch.log(x_safe / torch.tanh(0.5 * x_safe))
        x2 = x * x
        series = math.log(2.0) + x2 / 12 - 7 * x2 * x2 / 1440
        return torch.where(small, series, exact)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment