Skip to content

Instantly share code, notes, and snippets.

@justincosentino
Created March 10, 2020 19:27
Show Gist options
  • Save justincosentino/8e5045b1d9273dce60762935e16e1349 to your computer and use it in GitHub Desktop.
Save justincosentino/8e5045b1d9273dce60762935e16e1349 to your computer and use it in GitHub Desktop.
Selective loss function from SelectiveNet
def selective_loss(
    targets: torch.Tensor,
    f_out: torch.Tensor,
    g_out: torch.Tensor,
    target_coverage: float,
    lmbda: float = 32,
    normalize_by_coverage: bool = False,
) -> torch.Tensor:
    """Compute the SelectiveNet selective loss for a batch.

    The loss is ``risk + lmbda * psi(target_coverage - coverage)``, where
    ``coverage`` is the mean of ``g_out`` and ``psi(a) = max(0, a) ** 2``
    penalizes only coverage that falls short of the target.

    Args:
        targets: target values, in whatever form
            ``torch.nn.functional.cross_entropy`` accepts for ``f_out``.
        f_out: output of the classification head (logits passed to
            ``cross_entropy``).
        g_out: output of the selection head. After ``squeeze()`` it must
            broadcast against the per-sample cross-entropy loss; presumably
            shape ``(N,)`` or ``(N, 1)`` -- TODO confirm with callers.
        target_coverage: desired empirical coverage (mean of ``g_out``).
        lmbda: weight of the quadratic coverage penalty.
        normalize_by_coverage: if True, divide the selective risk by the
            empirical coverage, as in equation (2) of the SelectiveNet paper.
            Defaults to False, which keeps the unnormalized risk computed by
            earlier versions of this function.

    Returns:
        The selective loss as a 0-dim tensor.
    """

    def emp_sr(
        targets: torch.Tensor, f_out: torch.Tensor, g_out: torch.Tensor
    ) -> torch.Tensor:
        """Return the selection-weighted mean cross-entropy.

        This is the numerator of the empirical selective risk in equation (2)
        of the SelectiveNet paper; division by the coverage is applied by the
        caller only when ``normalize_by_coverage`` is set.
        """
        el_loss = torch.nn.functional.cross_entropy(f_out, targets, reduction="none")
        return (el_loss * g_out.squeeze()).mean()

    def emp_cov(g_out: torch.Tensor) -> torch.Tensor:
        """Return the empirical coverage, i.e. the mean selection score."""
        return torch.mean(g_out)

    def psi(a: torch.Tensor) -> torch.Tensor:
        """Quadratic penalty ``max(0, a) ** 2``.

        ``clamp`` keeps the result on ``a``'s own device, so no global
        device setting is needed.
        """
        return torch.pow(torch.clamp(a, min=0.0), 2)

    coverage = emp_cov(g_out)
    risk = emp_sr(targets, f_out, g_out)
    if normalize_by_coverage:
        # Clamp so an all-zero selection yields 0 / tiny = 0 instead of NaN.
        risk = risk / coverage.clamp(min=torch.finfo(coverage.dtype).tiny)
    return risk + lmbda * psi(target_coverage - coverage)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment