# Gist: saimadhu-polamuri/bdde9d92da757bc9374976e965ecaba7
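def get_predictions(model, texts):
    """Return the predicted textcat label for each text in `texts`."""
    # Tokenize each input text with the model's tokenizer (no other pipeline components run)
    docs = [model.tokenizer(text) for text in texts]
    # Use the textcat pipe to get the per-label scores for each doc.
    # In spaCy v2, predict() returns (scores, tensors); only the scores are needed here.
    textcat = model.get_pipe('textcat')
    scores, _ = textcat.predict(docs)
    # For each doc, take the index of the label with the highest score/probability
    predicted_labels = scores.argmax(axis=1)
    # Map each index back to its label name
    predicted_class = [textcat.labels[label] for label in predicted_labels]
    return predicted_class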
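The last step of `get_predictions`, `scores.argmax(axis=1)` followed by a lookup in `textcat.labels`, can be illustrated without spaCy or a trained model. The sketch below is a pure-Python stand-in under that assumption: `scores_to_labels` is a hypothetical helper (not part of spaCy), and the label names and score values are made up for illustration.

```python
from typing import List, Sequence


def scores_to_labels(scores: Sequence[Sequence[float]], labels: Sequence[str]) -> List[str]:
    """Hypothetical helper: pick the highest-scoring label for each row of scores.

    Mirrors `scores.argmax(axis=1)` plus the `textcat.labels[...]` lookup,
    using plain Python lists instead of a NumPy array.
    """
    predicted = []
    for row in scores:
        if len(row) != len(labels):
            raise ValueError("each score row must have one value per label")
        # Index of the maximum score; max() keeps the first index on ties,
        # which matches NumPy's argmax behaviour.
        best_index = max(range(len(row)), key=lambda i: row[i])
        predicted.append(labels[best_index])
    return predicted


# Made-up scores for two documents over two illustrative labels.
example_labels = ["ham", "spam"]
example_scores = [[0.9, 0.1], [0.2, 0.8]]
print(scores_to_labels(example_scores, example_labels))  # ['ham', 'spam']
```

Each row holds one score per label, in the same order as the label list, so the winning index can be used directly to look up the label name.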