Skip to content

Instantly share code, notes, and snippets.

@rohit-gupta
Created February 19, 2024 19:46
Show Gist options
  • Save rohit-gupta/b7ab765840a2d52caf13ba2f9587a09f to your computer and use it in GitHub Desktop.
Save rohit-gupta/b7ab765840a2d52caf13ba2f9587a09f to your computer and use it in GitHub Desktop.
Shrinkmatch loss
def _shrink_sample(sorted_probs, sorted_logits, max_prob, conf_thresh):
    """Shrink one sample's class space; return ``(weighted_loss, c)`` or ``None``.

    ``sorted_probs`` and ``sorted_logits`` are 1-D rows already ordered by
    decreasing pseudo-label probability, so position 0 is the top class.
    Candidate cut points ``c`` run from 2 to ``num_classes - 1``. The classes
    at sorted positions ``1 .. c-1`` are dropped until the top class's
    renormalised confidence reaches ``conf_thresh``, or until only the
    smallest class is left besides it.

    Returns ``None`` when there are fewer than three classes, because then
    no cut point exists.
    """
    num_classes = sorted_probs.shape[0]
    top_prob = sorted_probs[0]
    for c in range(2, num_classes):
        # Top-class confidence after removing sorted positions 1 .. c-1.
        sub_conf = top_prob / (top_prob + sorted_probs[c:].sum())
        reached = sub_conf >= conf_thresh
        # Stop once the threshold is met, or when only the final
        # (smallest-probability) class remains as a competitor.
        if reached or c == num_classes - 1:
            sub_logits = torch.cat([sorted_logits[:1], sorted_logits[c:]], dim=0)
            # Target index 0 is the top class in the shrunk space. Build it on
            # the logits' device rather than a hard-coded .cuda(), so CPU and
            # non-default-GPU inputs work.
            target = torch.zeros(1, dtype=torch.long, device=sub_logits.device)
            loss = F.cross_entropy(sub_logits.unsqueeze(0), target, reduction="none")[0]
            # Loss reweighting principle 1: scale by the sample's top
            # probability, and zero the loss if the threshold was never reached.
            return loss * (max_prob * reached), c
    return None


def shrink_loss(pseudo_label, logits_u_s, conf_thresh):
    """Compute the ShrinkMatch loss over the uncertain samples of a batch.

    A sample is *uncertain* when its largest pseudo-label probability is below
    ``conf_thresh``; confident samples contribute nothing here. For each
    uncertain sample, classes are ordered by decreasing pseudo-label
    probability. The runner-up classes are then removed progressively until
    the top class's renormalised confidence reaches ``conf_thresh``, or until
    only the smallest class is left. A cross-entropy loss with the top class
    as target is computed on ``logits_u_s`` restricted to the kept classes.
    That loss is weighted by the sample's top probability, or zeroed if the
    threshold was never reached.

    Args:
        pseudo_label: ``(B, C)`` tensor. Presumably detached class
            probabilities from the weakly-augmented view; confirm with caller.
        logits_u_s: logits to supervise, expected to have the same ``(B, C)``
            shape as ``pseudo_label``. Presumably the strong-view logits.
        conf_thresh: confidence threshold (float).

    Returns:
        ``(loss, removed_class_idx)``.

        * ``loss`` is a 0-dim tensor: the sum of the per-sample weighted
          losses. It is zero when no sample is shrunk.
        * ``removed_class_idx`` is a list of ints, one per shrunk uncertain
          sample in batch order. Each entry is the sorted position ``c``
          where shrinking stopped, so ``c - 1`` classes were removed. It is
          empty when ``C < 3`` or every sample is confident.
    """
    _, num_classes = pseudo_label.shape  # also asserts a 2-D input
    # Accumulate into a tensor so the return type is always a tensor,
    # even when nothing is shrunk.
    loss_u_shrink_batch = logits_u_s.new_zeros(())
    removed_class_idx = []

    max_probs = pseudo_label.max(dim=-1)[0]
    # Decide per sample on its top probability, not with an element-wise
    # mask over all classes.
    uncertain = ~max_probs.ge(conf_thresh)
    if not uncertain.any():  # no uncertain samples to shrink
        return loss_u_shrink_batch, removed_class_idx

    sorted_prob_w, sorted_idx = pseudo_label.topk(num_classes, dim=-1, sorted=True)
    # Reorder the logits to match the sorted pseudo-label order.
    sorted_logits_s = logits_u_s.gather(dim=-1, index=sorted_idx)

    # One host sync to fetch the uncertain indices (ascending batch order).
    for b in torch.nonzero(uncertain).flatten().tolist():
        result = _shrink_sample(sorted_prob_w[b], sorted_logits_s[b], max_probs[b], conf_thresh)
        if result is None:
            continue
        loss_u_shrink, c = result
        loss_u_shrink_batch = loss_u_shrink_batch + loss_u_shrink
        removed_class_idx.append(c)
    return loss_u_shrink_batch, removed_class_idx
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment