@aribornstein
Created May 20, 2021 16:06
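from typing import Any

import torch
# Assumed import path (Lightning Flash); adjust it to your installed version.
from flash.core.classification import ClassificationSerializer


class Probabilities(ClassificationSerializer):
    """
    A :class:`.Serializer` which applies a softmax to the model outputs
    (assumed to be logits) and converts the result to a list.
    """
    def serialize(self, sample: Any) -> Any:
        return torch.softmax(sample, -1).tolist()  # softmax over the last (class) dimension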
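
To make the transformation concrete, here is a minimal pure-Python sketch of what a softmax over a flat list of logits computes. The helper name `softmax_to_list` is hypothetical and not part of the gist or of any library; it only mirrors, for the 1-D case, the result one would expect from `torch.softmax(sample, -1).tolist()`.

```python
import math
from typing import List


def softmax_to_list(logits: List[float]) -> List[float]:
    """Hypothetical helper: softmax over a flat list of logits."""
    # Subtracting the maximum before exponentiating avoids overflow
    # and leaves the resulting probabilities unchanged.
    shift = max(logits)
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


probs = softmax_to_list([2.0, 1.0, 0.1])
print(probs)  # three probabilities that sum to 1, largest first
```

The output is a plain list of floats, which is why the serializer's result can be passed straight to JSON encoding or similar downstream consumers.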