@MathiasGruber
Created July 2, 2021 06:38
Ordinal regression loss function for
import torch
from torch import nn


def ordinal_regression(predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Ordinal regression with encoding as in https://arxiv.org/pdf/0704.1028.pdf"""
    # Create our modified target with [batch_size, num_labels] shape
    modified_target = torch.zeros_like(predictions)
    # Fill in the ordinal target encoding, i.e. 0 -> [1,0,0,...], 1 -> [1,1,0,...]
    for i, target in enumerate(targets):
        # Cast to int so the label can be used as a slice bound
        modified_target[i, 0:int(target) + 1] = 1
    # Per-sample loss: squared error summed over the label dimension
    return nn.MSELoss(reduction='none')(predictions, modified_target).sum(dim=1)
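The loop above builds a cumulative target in which label k becomes k+1 leading ones. As a minimal sketch, the same encoding can be built without a Python loop, and predictions can be mapped back to labels by counting the leading outputs above a threshold. The helper names `encode_ordinal` and `decode_ordinal` and the 0.5 threshold are assumptions made for illustration and are not part of the original gist.

```python
import torch


def encode_ordinal(targets: torch.Tensor, num_labels: int) -> torch.Tensor:
    """Hypothetical helper: label k -> k+1 leading ones, e.g. 2 -> [1, 1, 1, 0]."""
    positions = torch.arange(num_labels)  # shape [num_labels]
    # Broadcast to [batch_size, num_labels]; position j is 1 when j <= target
    return (positions.unsqueeze(0) <= targets.unsqueeze(1)).float()


def decode_ordinal(predictions: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
    """Hypothetical helper: number of leading outputs above threshold, minus one."""
    above = (predictions > threshold).long()
    # cumprod zeroes out every position after the first output below threshold
    return above.cumprod(dim=1).sum(dim=1) - 1


targets = torch.tensor([0, 2, 3])
encoded = encode_ordinal(targets, num_labels=4)
decoded = decode_ordinal(encoded)
```

With this decoding, a row whose first output is below the threshold yields -1, so callers may want to clamp the result to 0 if every sample must receive a valid label.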