Skip to content

Instantly share code, notes, and snippets.

@khdlr
Created April 25, 2022 15:34
Show Gist options
  • Save khdlr/c0bb6a653ac00c6d6859e6dc773daf96 to your computer and use it in GitHub Desktop.
Save khdlr/c0bb6a653ac00c6d6859e6dc773daf96 to your computer and use it in GitHub Desktop.
import re
import torch.nn
import torch.nn.functional as F
def get_loss(loss_args):
    """Build a loss callable from a configuration dict.

    ``loss_args['type']`` selects the loss:

    * ``'BCE'`` -> a ``torch.nn.BCEWithLogitsLoss`` instance; an optional
      ``loss_args['pos_weight']`` is forwarded as a 0-d tensor.
    * ``'FocalLoss'`` -> ``focal_loss_with_logits``; an optional
      ``loss_args['gamma']`` overrides its default exponent.
    * ``'AutoBCE'`` -> ``auto_weight_bce``.
    * ``'HybridLoss'`` -> ``hybrid_loss``.
    * ``'Summed(<type>)'`` -> the loss for ``<type>`` wrapped with
      ``sum_loss``. The inner type may itself be wrapped, e.g.
      ``'Summed(Summed(BCE))'``. The remaining keys of ``loss_args`` are
      passed through unchanged to the inner lookup.

    Args:
        loss_args: Mapping with at least a ``'type'`` key. It is not mutated;
            a copy is made when a wrapper is unpacked.

    Returns:
        A callable ``loss(prediction, target)``.

    Raises:
        KeyError: if ``loss_args`` has no ``'type'`` key.
        ValueError: if the type, or the wrapper name, is not recognised.
    """
    loss_type = loss_args['type']

    # Use fullmatch (not search) so that a string such as 'x-Summed(BCE)' is
    # rejected instead of being treated as its 'Summed(BCE)' substring.
    # '(.+)' lets the argument itself be a wrapped type, so wrappers nest.
    functional_style = re.fullmatch(r'(\w+)\((.+)\)', loss_type)
    if functional_style:
        func, arg = functional_style.groups()
        if func == 'Summed':
            new_args = dict(loss_args)
            new_args['type'] = arg
            return sum_loss(get_loss(new_args))
        # Any other wrapper name is unknown and falls through to the
        # ValueError at the end.

    if loss_type == 'BCE':
        args = {}
        if 'pos_weight' in loss_args:
            # The BCEWithLogitsLoss API takes pos_weight as a Tensor.
            args['pos_weight'] = loss_args['pos_weight'] * torch.ones([])
        return torch.nn.BCEWithLogitsLoss(**args)
    if loss_type == 'FocalLoss':
        if 'gamma' in loss_args:
            import functools
            return functools.partial(focal_loss_with_logits,
                                     gamma=loss_args['gamma'])
        return focal_loss_with_logits
    if loss_type == 'AutoBCE':
        return auto_weight_bce
    if loss_type == 'HybridLoss':
        return hybrid_loss
    raise ValueError(f"No Loss of type {loss_type} known")
def focal_loss_with_logits(y_hat_log, y, gamma=2):
    """Binary focal loss evaluated directly on logits.

    Log-probabilities come from ``F.logsigmoid`` rather than ``log(sigmoid)``,
    and the per-element loss is averaged over every element.

    Args:
        y_hat_log: Raw (pre-sigmoid) predictions.
        y: Targets with the same shape as (or broadcastable to)
            ``y_hat_log``; presumably in ``[0, 1]`` -- TODO confirm with callers.
        gamma: Exponent of the modulating factor ``|y - p|``.

    Returns:
        A scalar tensor holding the mean focal loss.
    """
    # log(1 - sigmoid(x)) and log(sigmoid(x)), in that order.
    log_neg, log_pos = F.logsigmoid(-y_hat_log), F.logsigmoid(y_hat_log)

    # Modulating factors: exp(log_neg) is 1 - p and exp(log_pos) is p.
    mod_neg = torch.pow(torch.abs(1 - y - torch.exp(log_neg)), gamma)
    mod_pos = torch.pow(torch.abs(y - torch.exp(log_pos)), gamma)

    negative_part = -(1 - y) * mod_neg * log_neg
    positive_part = y * mod_pos * log_pos
    return torch.mean(negative_part - positive_part)
def auto_weight_bce(y_hat_log, y):
    """Class-balanced binary cross-entropy on logits.

    The balancing weight ``beta`` is the mean of ``y`` over dims 2 and 3.
    It is kept as size-1 dims for broadcasting and computed without
    gradient tracking.

    Positive targets are weighted by ``1 - beta`` and negative targets by
    ``beta``. For binary ``y``, ``beta`` is the fraction of positives, so the
    rarer class in each slice gets the larger weight.

    Args:
        y_hat_log: Raw (pre-sigmoid) predictions, laid out like ``y``
            (presumably ``(N, C, H, W)`` -- TODO confirm).
        y: Targets with at least 4 dims; presumably in ``[0, 1]``.

    Returns:
        A scalar tensor: the mean of the weighted per-element loss.
    """
    with torch.no_grad():
        # Treated as a constant weight: no gradient flows through beta.
        beta = y.mean(dim=[2, 3], keepdim=True)
    log_p = F.logsigmoid(y_hat_log)              # log(sigmoid(x))
    log_one_minus_p = F.logsigmoid(-y_hat_log)   # log(1 - sigmoid(x))

    positive_term = -(1 - beta) * log_p * y
    negative_term = beta * log_one_minus_p * (1 - y)
    return (positive_term - negative_term).mean()
def hybrid_loss(y_hat_log, y):
    """Equal-weight average of a plain BCE and an auto-weighted BCE.

    Channel 0 (``[:, :1]``) is scored with standard BCE-with-logits. The
    remaining channels (``[:, 1:]``) are scored with ``auto_weight_bce``.
    The two terms are then averaged.

    Args:
        y_hat_log: Raw (pre-sigmoid) predictions with channels on dim 1 and
            at least 2 channels (presumably ``(N, C, H, W)`` -- TODO confirm).
        y: Targets laid out like ``y_hat_log``.

    Returns:
        A scalar tensor.

    Raises:
        ValueError: if either input has fewer than 2 channels. Otherwise the
            second term would be the mean of an empty tensor and the loss
            would silently become NaN.
    """
    channels = min(y_hat_log.shape[1], y.shape[1])
    if channels < 2:
        raise ValueError(
            f"hybrid_loss expects at least 2 channels, got {channels}"
        )
    loss_seg = F.binary_cross_entropy_with_logits(y_hat_log[:, :1], y[:, :1])
    loss_edge = auto_weight_bce(y_hat_log[:, 1:], y[:, 1:])
    return 0.5 * (loss_seg + loss_edge)
def sum_loss(loss_fn):
    """Wrap ``loss_fn`` so it also accepts sequences of predictions.

    The returned callable ``loss(prediction, target)`` behaves as follows:

    * If ``prediction`` is a list or tuple, ``target`` must have the same
      length. ``loss_fn`` is applied pairwise and the stacked results are
      summed.
    * Otherwise it returns ``loss_fn(prediction, target)`` unchanged.

    Args:
        loss_fn: Callable ``(prediction, target) -> tensor``. In the
            sequence case every result must have the same shape, because
            they are combined with ``torch.stack``.

    Returns:
        The wrapping callable. That callable raises ``ValueError`` if the
        prediction and target sequences differ in length; a bare ``zip``
        would silently drop the unmatched items.
    """
    def loss(prediction, target):
        if not isinstance(prediction, (list, tuple)):
            return loss_fn(prediction, target)
        if len(prediction) != len(target):
            raise ValueError(
                f"got {len(prediction)} predictions but {len(target)} targets"
            )
        losses = torch.stack([loss_fn(p, t) for p, t in zip(prediction, target)])
        return torch.sum(losses)
    return loss
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment