@MathiasGruber
Created July 10, 2021 07:13
Conversion from ordinal predictions to class labels
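```python
import numpy as np

def prediction2label(pred: np.ndarray) -> np.ndarray:
    """Convert ordinal predictions of shape (n_samples, n_outputs) to class labels, e.g.
    [0.9, 0.1, 0.1, 0.1] -> 0
    [0.9, 0.9, 0.1, 0.1] -> 1
    [0.9, 0.9, 0.9, 0.1] -> 2
    etc.
    """
    # Threshold each output at 0.5; cumprod zeroes everything after the first
    # False in a row, so the row sum counts the leading run of True values.
    # Note: a row whose first output is <= 0.5 maps to -1.
    return (pred > 0.5).cumprod(axis=1).sum(axis=1) - 1
```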
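A minimal sketch of the full round trip, assuming the target encoding implied by the docstring examples: class k is represented by k + 1 leading ones, so with 4 outputs class 0 is [1, 0, 0, 0] and class 2 is [1, 1, 1, 0]. `label2ordinal` is a hypothetical helper for illustration, not part of the gist; `prediction2label` is repeated so the snippet runs on its own.

```python
import numpy as np


def label2ordinal(labels: np.ndarray, n_outputs: int) -> np.ndarray:
    """Hypothetical inverse of prediction2label: class k -> k + 1 leading ones.

    e.g. with n_outputs=4: 0 -> [1, 0, 0, 0], 2 -> [1, 1, 1, 0]
    """
    # Broadcast output indices (1, n_outputs) against labels (n, 1):
    # position j is 1 exactly when j <= label.
    return (np.arange(n_outputs)[None, :] <= labels[:, None]).astype(np.float32)


def prediction2label(pred: np.ndarray) -> np.ndarray:
    # Same decoding as the gist: count the leading run of outputs above 0.5.
    return (pred > 0.5).cumprod(axis=1).sum(axis=1) - 1


labels = np.array([0, 1, 2, 3])
targets = label2ordinal(labels, n_outputs=4)

# Round trip: perfectly confident predictions decode back to the labels.
assert (prediction2label(targets) == labels).all()

# A non-monotone prediction: the 0.2 in position 1 ends the leading run,
# so the later 0.8 is ignored and the row decodes to class 0.
noisy = np.array([[0.9, 0.2, 0.8, 0.1]])
print(prediction2label(noisy))  # [0]
```

Because decoding stops at the first output below the threshold, inconsistent predictions resolve to the lowest class consistent with the leading run. If a row with every output below 0.5 should map to class 0 instead of -1, the result can be wrapped in `np.clip(..., 0, None)`.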