Skip to content

Instantly share code, notes, and snippets.

@erogol
Last active January 29, 2023 21:02
Show Gist options
  • Save erogol/c37628286f8efdb62b3cc87aad382f9e to your computer and use it in GitHub Desktop.
Save erogol/c37628286f8efdb62b3cc87aad382f9e to your computer and use it in GitHub Desktop.
Online hard example mining PyTorch
import torch as th
class NLL_OHEM(th.nn.NLLLoss):
    """Negative log-likelihood loss with online hard example mining (OHEM).

    Only the ``ratio`` fraction of samples with the highest per-sample loss
    in each batch contributes to the returned (mean) loss.

    The input must be log-probabilities, e.g. the output of
    ``nn.LogSoftmax()``.
    """

    def __init__(self, ratio):
        """Create the loss.

        Args:
            ratio: Fraction of each batch to keep as hard examples.
        """
        # Keyword form avoids passing the deprecated positional
        # ``size_average`` flag; mean reduction is the default either way.
        super().__init__(weight=None)
        self.ratio = ratio

    def forward(self, x, y, ratio=None):
        """Return the mean NLL over the hardest samples of the batch.

        Args:
            x: Log-probabilities, indexed as ``x[sample, class]``.
            y: Integer class index for each sample.
            ratio: Optional new keep-ratio. When given it replaces
                ``self.ratio`` and stays in effect for later calls.

        Returns:
            Scalar tensor: mean NLL of the selected hard examples.
        """
        if ratio is not None:
            self.ratio = ratio

        num_inst = x.size(0)
        # Keep at least one sample (an empty selection would yield NaN) and
        # at most the whole batch (topk raises if k exceeds the batch size).
        num_hns = min(num_inst, max(1, int(self.ratio * num_inst)))

        # Per-sample losses are used only to rank samples, so compute them
        # detached from the graph. gather() runs on whatever device ``x``
        # lives on, so no explicit device transfer is needed.
        inst_losses = -x.detach().gather(1, y.unsqueeze(1)).squeeze(1)

        _, idxs = inst_losses.topk(num_hns)
        x_hn = x.index_select(0, idxs)
        y_hn = y.index_select(0, idxs)
        return th.nn.functional.nll_loss(x_hn, y_hn)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment