Skip to content

Instantly share code, notes, and snippets.

@teh

teh/bi-tempered.py

Last active Aug 27, 2019
Embed
What would you like to do?
# source: https://google.github.io/bi-tempered-loss/
# apache 2 licensed
import torch
def logT(u, t):
    """Tempered logarithm log_t(u).

    For t == 1 this is the natural logarithm; otherwise it is
    (u ** (1 - t) - 1) / (1 - t).

    Args:
        u: tensor of inputs.
        t: temperature (scalar).

    Returns:
        Tensor with the same shape as ``u``.
    """
    if t != 1:
        exponent = 1.0 - t
        return (torch.pow(u, exponent) - 1.0) / exponent
    return torch.log(u)
def expT(u, t):
    """Tempered exponential exp_t(u), the inverse of ``logT``.

    For t == 1 this is ``torch.exp``; otherwise it is
    max(1 + (1 - t) * u, 0) ** (1 / (1 - t)).

    Args:
        u: tensor of inputs.
        t: temperature (scalar).

    Returns:
        Tensor with the same shape as ``u``.
    """
    if t == 1:
        return torch.exp(u)
    # Clamp the base at zero *before* exponentiating. A negative base with a
    # fractional exponent yields NaN (and relu(NaN) is still NaN). With an
    # even integer exponent it yields a spurious positive value. Either way,
    # clamping afterwards cannot recover the intended 0.
    base = torch.relu(1.0 + (1.0 - t) * u)
    return torch.pow(base, 1.0 / (1.0 - t))
def computeNormalization(activations, t, numIters=5):
    """Estimate the tempered-softmax normalization constant along the last axis.

    ``temperedSoftmax`` subtracts the returned value from ``activations``
    before applying ``expT``. The value is obtained by a fixed-point
    iteration run for exactly ``numIters`` steps; there is no convergence
    check.

    Args:
        activations: logits tensor; the last axis is normalized over.
        t: temperature (scalar), expected != 1 here.
        numIters: number of fixed-point iterations.

    Returns:
        Tensor shaped like ``activations`` except the last axis has size 1.
    """
    # Shift by the per-row maximum so the iteration works on values <= 0.
    rowMax = torch.max(activations, -1, keepdim=True).values
    shifted = activations - rowMax

    current = shifted
    for _ in range(numIters):
        partition = expT(current, t).sum(-1, keepdim=True)
        current = shifted * partition.pow(1 - t)

    partition = expT(current, t).sum(-1, keepdim=True)
    return rowMax - logT(1.0 / partition, t)
def temperedSoftmax(activations, t, numIters=5):
    """Tempered softmax over the last axis.

    For t == 1 this is the ordinary softmax. Otherwise the normalization
    constant is estimated by ``computeNormalization`` and the result is
    ``expT(activations - constant, t)``.

    Args:
        activations: logits tensor; the last axis indexes classes.
        t: temperature (scalar).
        numIters: iterations passed to ``computeNormalization`` (unused
            when t == 1).

    Returns:
        Tensor with the same shape as ``activations``.
    """
    if t == 1.0:
        if not activations.is_floating_point():
            # torch.softmax rejects integer tensors. The previous
            # exp/log path promoted them to the default float dtype.
            activations = activations.to(torch.get_default_dtype())
        # torch.softmax shifts by the max internally. Computing
        # log(sum(exp(x))) directly overflows for large logits and
        # returned all zeros (e.g. for [1000, 0]).
        return torch.softmax(activations, -1)
    normalizationConstants = computeNormalization(activations, t, numIters)
    return expT(activations - normalizationConstants, t)
def temperedSigmoid(activations, t, numIters=5):
    """Two-class tempered probabilities for each activation.

    ``activations`` is flattened, and each element is paired with a fixed
    logit of 0. Each pair is passed through ``temperedSoftmax``.

    Args:
        activations: tensor of logits (any shape; flattened here).
        t: temperature (scalar).
        numIters: iterations passed to ``temperedSoftmax``.

    Returns:
        Tensor of shape (activations.numel(), 2). Column 0 holds the
        probability of the zero logit; column 1 holds that of the activation.

    NOTE(review): despite the name, this returns both columns rather than a
    tensor shaped like ``activations``. A per-element sigmoid would be
    column 1 reshaped to the input shape. Confirm what callers expect.
    """
    column = activations.reshape(-1, 1)
    pairedLogits = torch.cat((torch.zeros_like(column), column), dim=1)
    return temperedSoftmax(pairedLogits, t, numIters)
def bitemperedLogisticLoss(activations, labels, t1, t2, numIters=5):
    """Bi-tempered logistic loss, reduced by summing over the last axis.

    Args:
        activations: logits tensor; the last axis indexes classes.
        labels: target values, assumed to have the same shape as
            ``activations`` (not checked here).
        t1: temperature used for ``logT`` in the loss.
        t2: temperature passed to ``temperedSoftmax``.
        numIters: iterations for the tempered-softmax normalization.

    Returns:
        Loss tensor with the last axis summed out.
    """
    probs = temperedSoftmax(activations, t2, numIters)
    # The 1e-10 offset keeps logT(labels) finite where a label is exactly 0.
    # For t1 >= 1, logT(0) would be -inf/inf and 0 * inf would give NaN.
    crossTerm = labels * (logT(labels + 1e-10, t1) - logT(probs, t1))
    twoMinusT1 = 2.0 - t1
    normTerm = (1.0 / twoMinusT1) * (
        torch.pow(labels, twoMinusT1) - torch.pow(probs, twoMinusT1)
    )
    return (crossTerm - normTerm).sum(-1)
def bitemperedBinaryLogisticLoss(activations, labels, t1, t2, numIters=5):
    """Binary bi-tempered logistic loss, one value per element of ``labels``.

    Each activation is treated as the logit of class 1 against a fixed logit
    of 0 for class 0. Each label is the target for class 1, and ``1 - label``
    is used for class 0. The resulting two-class problem is scored with
    ``bitemperedLogisticLoss``.

    Args:
        activations: tensor of logits; assumed to have the same number of
            elements as ``labels`` (not checked here).
        labels: tensor of class-1 targets.
        t1: temperature used for ``logT`` in the loss.
        t2: temperature passed to the tempered softmax.
        numIters: iterations for the tempered-softmax normalization.

    Returns:
        Loss tensor reshaped to ``labels.shape``.
    """
    outShape = labels.shape
    # The duplicate reshape of ``labels`` that used to follow is removed.
    labels2d = torch.reshape(labels, [-1, 1])
    activations2d = torch.reshape(activations, [-1, 1])
    internalLabels = torch.cat([1.0 - labels2d, labels2d], 1)
    internalLogits = torch.cat([torch.zeros_like(activations2d), activations2d], 1)
    losses = bitemperedLogisticLoss(internalLogits, internalLabels, t1, t2, numIters)
    return torch.reshape(losses, outShape)
def main():
    """Smoke test: binary bi-tempered loss on a small fixed batch.

    Prints the labels, the per-element losses and the t=1 tempered sigmoid,
    then checks the losses against hard-coded expected values.
    """
    batchSize = 10
    labels = torch.ones(batchSize)
    activations = torch.arange(batchSize).float() / batchSize
    loss = bitemperedBinaryLogisticLoss(activations, labels, t1=0.2, t2=0.8)
    print(labels)
    print(loss)
    print(temperedSigmoid(activations, 1.0))
    # Expected losses, compared with a relative tolerance of 5e-3.
    expected = torch.tensor([
        0.2956, 0.2636, 0.2336, 0.2056, 0.1798,
        0.1562, 0.1347, 0.1154, 0.0980, 0.0827,
    ])
    assert torch.allclose(loss, expected, rtol=5e-3)
# Run the demo only when executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment