Skip to content

Instantly share code, notes, and snippets.

@discort
Created September 12, 2021 17:22
Show Gist options
  • Save discort/1036c9d5dcf242066480ab613c53c7c4 to your computer and use it in GitHub Desktop.
Save discort/1036c9d5dcf242066480ab613c53c7c4 to your computer and use it in GitHub Desktop.
Online hard example mining
import torch
from torch import Tensor
import torch.nn as nn
class OHEM(nn.Module):
    """
    Online hard example mining.

    Wraps a loss module, ranks the samples of a batch by their loss and
    reduces only over the ``ratio`` fraction of samples with the largest loss.
    Details: <https://arxiv.org/pdf/1604.03540.pdf>

    Args:
        loss_fn: loss module that reads a ``reduction`` attribute (for
            example ``nn.CrossEntropyLoss``). Its ``reduction`` is set to
            ``'none'`` so per-sample losses are available; note that this
            mutates the module that is passed in.
        ratio: fraction of the batch to keep, in ``(0, 1]``. At least one
            sample is always kept for a non-empty batch.
        reduction: how the kept losses are combined: ``'mean'`` (default),
            ``'sum'``, or ``'none'`` (return the kept losses unreduced).

    Raises:
        ValueError: if ``ratio`` is outside ``(0, 1]`` or ``reduction`` is
            not one of ``'mean'``, ``'sum'``, ``'none'``.
    """

    _REDUCTIONS = ('mean', 'sum', 'none')

    def __init__(self,
                 loss_fn: nn.Module,
                 ratio: float = 0.7,
                 reduction: str = 'mean'):
        super().__init__()
        if not 0 < ratio <= 1:
            raise ValueError(f"ratio must be in (0, 1], got {ratio!r}")
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, "
                f"got {reduction!r}")
        self.ratio = ratio
        self.reduction = reduction
        self.loss_fn = loss_fn
        # Unreduced losses are required so the samples can be ranked.
        self.loss_fn.reduction = 'none'

    def forward(self,
                pred: Tensor,
                target: Tensor,
                dim: int = 1) -> Tensor:
        """
        Compute the loss over the hardest ``ratio`` fraction of the batch.

        Args:
            pred: model output, passed to ``loss_fn`` as its first argument.
            target: ground truth, passed to ``loss_fn`` as its second
                argument.
            dim: only used when ``loss_fn`` returns a loss with two or more
                dimensions: samples are then ranked by ``loss[:, dim]``
                (averaged over any further dimensions). Ignored for a 1-D
                per-sample loss, which is ranked directly.

        Returns:
            The reduced loss of the kept samples (a scalar for ``'mean'`` /
            ``'sum'``), or the kept rows of the unreduced loss for ``'none'``.
        """
        loss = self.loss_fn(pred, target)
        num_samples = loss.size(0)

        if loss.dim() == 1:
            # Already one loss value per sample (e.g. CrossEntropyLoss on
            # (N, C) logits); indexing loss[:, dim] here would raise.
            scores = loss
        else:
            scores = loss[:, dim]
            if scores.dim() > 1:
                # Collapse trailing dims so every sample has a single score.
                scores = scores.flatten(1).mean(dim=1)

        # int() truncates, so a small batch could yield k == 0 and a NaN
        # mean; keep at least one sample but never more than the batch holds.
        k = min(max(1, int(self.ratio * num_samples)), num_samples)
        _, idxs = torch.topk(scores, k)
        kept = loss[idxs]

        if self.reduction == 'sum':
            return kept.sum()
        if self.reduction == 'none':
            return kept
        return kept.mean()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment