@vfmatzkin
Created September 1, 2021 01:57
Implementation of the Brabandere Loss term, used in https://arxiv.org/abs/1708.02551
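For reference, the three terms of the loss as defined in the linked paper can be written out as below. Here C is the number of clusters, N_c the number of embeddings in cluster c, x_i an embedding, \mu_c the mean of cluster c, and [x]_+ = max(0, x). The margins \delta_v and \delta_d are assumed to correspond to the `dv` and `dd` arguments of the function, and \alpha, \beta, \gamma to its weights.

```latex
\begin{aligned}
L_{var}  &= \frac{1}{C} \sum_{c=1}^{C} \frac{1}{N_c} \sum_{i=1}^{N_c}
            \big[\, \lVert \mu_c - x_i \rVert - \delta_v \,\big]_+^2 \\
L_{dist} &= \frac{1}{C(C-1)} \sum_{c_A=1}^{C} \sum_{\substack{c_B=1 \\ c_B \neq c_A}}^{C}
            \big[\, 2\delta_d - \lVert \mu_{c_A} - \mu_{c_B} \rVert \,\big]_+^2 \\
L_{reg}  &= \frac{1}{C} \sum_{c=1}^{C} \lVert \mu_c \rVert \\
L        &= \alpha\, L_{var} + \beta\, L_{dist} + \gamma\, L_{reg}
\end{aligned}
```

Note that the sum in L_dist runs over ordered pairs of distinct clusters, so each unordered pair contributes twice.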
import torch


def brabandere_loss(pred, alpha=1, beta=1, gamma=1e-3, dv=.05, dd=2):
    """
    Implementation of the loss proposed in:
    Semantic Instance Segmentation with a Discriminative Loss Function (https://arxiv.org/abs/1708.02551)
    The loss is a weighted sum of three terms:
    1) An intra-cluster attraction force (L_var).
    2) An inter-cluster repelling force (L_dist).
    3) A regularization term on the cluster means (L_reg).
    This loss is unsupervised, so only batches of codes for each cluster are needed.
    :param pred: List of batches of embeddings, one (N_c, D) tensor per cluster.
    :param alpha: L_var (intra-cluster) loss weight.
    :param beta: L_dist (inter-cluster) loss weight.
    :param gamma: Regularization loss weight.
    :param dv: Margin below which an embedding's distance to its cluster mean is not penalized.
    :param dd: Half of the distance between two cluster means above which the pair is not penalized.
    :return: Scalar tensor with the weighted loss.
    """
    C = len(pred)  # Number of clusters
    # Allocate the accumulators on the same device and dtype as the inputs
    zero = torch.zeros((), dtype=pred[0].dtype, device=pred[0].device)
    l_var, l_dist, l_reg = zero.clone(), zero.clone(), zero.clone()
    means = [torch.mean(cluster, dim=0) for cluster in pred]  # One mean per cluster
    for c, cluster in enumerate(pred):  # Each cluster
        Nc = len(cluster)
        # Lvar, accumulated over all clusters
        for emb in cluster:  # Each embedding
            # Penalize only embeddings farther than dv from their cluster mean
            l_var = l_var + torch.pow(torch.clamp(torch.norm(means[c] - emb, 2) - dv, min=0), 2) / (C * Nc)
        # Ldist
        for j in range(c + 1, C):  # The other clusters
            # Each unordered pair is visited once; the factor 2 matches the paper's sum over ordered pairs
            l_dist = l_dist + 2 * torch.pow(
                torch.clamp(2 * dd - torch.norm(means[c] - means[j], 2), min=0), 2) / (C * C - C)
        l_reg = l_reg + torch.norm(means[c], 2) / C  # Lreg
    return alpha * l_var + beta * l_dist + gamma * l_reg  # Weighted loss
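
The per-embedding Python loop can become slow for large clusters. Below is a minimal vectorized sketch of the same three terms, written against the formulas in the paper. It assumes each cluster is a 2-D tensor of shape (N_c, D) and that all clusters live on the same device. The name `discriminative_loss_vectorized` is hypothetical and not part of this gist, and the values it returns are only guaranteed to follow the paper's normalization, not any particular loop-based implementation.

```python
import torch


def discriminative_loss_vectorized(clusters, alpha=1.0, beta=1.0, gamma=1e-3, dv=0.05, dd=2.0):
    """Hypothetical vectorized variant; `clusters` is a list of (N_c, D) tensors."""
    C = len(clusters)
    means = torch.stack([c.mean(dim=0) for c in clusters])  # (C, D)

    # L_var: squared hinge on the distance of each embedding to its own mean,
    # averaged within each cluster and then across clusters.
    l_var = sum(
        torch.clamp(torch.linalg.norm(c - m, dim=1) - dv, min=0).pow(2).mean()
        for c, m in zip(clusters, means)
    ) / C

    # L_dist: squared hinge on the distance between distinct cluster means.
    # Only unordered pairs are enumerated, hence the factor 2.
    if C > 1:
        idx_a, idx_b = torch.triu_indices(C, C, offset=1, device=means.device)
        pair_dist = torch.linalg.norm(means[idx_a] - means[idx_b], dim=1)
        hinge = torch.clamp(2 * dd - pair_dist, min=0).pow(2)
        l_dist = 2 * hinge.sum() / (C * (C - 1))
    else:
        l_dist = means.new_zeros(())

    # L_reg: mean norm of the cluster means.
    l_reg = torch.linalg.norm(means, dim=1).mean()

    return alpha * l_var + beta * l_dist + gamma * l_reg


if __name__ == "__main__":
    # Three clusters of random 8-dimensional embeddings, on CPU for illustration.
    torch.manual_seed(0)
    codes = [torch.randn(n, 8, requires_grad=True) for n in (5, 7, 3)]
    loss = discriminative_loss_vectorized(codes)
    loss.backward()  # Gradients flow back to every embedding
    print(loss.item())
```

Working on whichever device the inputs already occupy, rather than a fixed `'cuda:0'`, lets the same function run in CPU-only tests and on multi-GPU setups.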