@mpangrazzi
Created September 20, 2022 19:54
Efficiently use a fill-mask pipeline on a large pretrained transformers model to categorise a list of words
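from transformers import pipeline

# Load the model once and reuse the pipeline; top_k=1 keeps only the best candidate.
pipe = pipeline("fill-mask", model="roberta-large", top_k=1)

def get_topic_using_fill_mask(words):
    # Ask the model to fill in the topic word that follows the word list.
    results = pipe(f"{words}. \n These words are about <mask>.")
    return results[0]["token_str"].strip()

print(get_topic_using_fill_mask(["terra, mars, moon, saturn, jupyter"]))
# prints "astronomy"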
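
The call above passes a list holding a single comma-separated string, so the f-string interpolates the list's repr (brackets and quotes included) into the prompt. A minimal sketch of one alternative, assuming you would rather pass the words as separate items: build the prompt with `", ".join(...)` and hand the fill-mask callable in as an argument. The names `build_prompt`, `get_topic`, and `fake_fill_mask` are hypothetical and not part of the gist or of transformers. The stub only mimics the shape of the pipeline's result so the snippet runs without downloading a model, and a different prompt may well change what the real model predicts.

```python
from typing import Callable, Sequence


def build_prompt(words: Sequence[str], mask_token: str = "<mask>") -> str:
    """Join the words with commas and append the masked question."""
    return f"{', '.join(words)}. \n These words are about {mask_token}."


def get_topic(words: Sequence[str], fill_mask: Callable) -> str:
    """Return the top candidate's token, stripped of surrounding whitespace."""
    results = fill_mask(build_prompt(words))
    return results[0]["token_str"].strip()


def fake_fill_mask(prompt: str) -> list:
    """Hypothetical stand-in for the real pipeline: returns a fixed candidate."""
    return [{"token_str": " astronomy", "score": 1.0, "sequence": prompt}]


print(build_prompt(["terra", "mars", "moon"]))
print(get_topic(["terra", "mars", "moon"], fake_fill_mask))
```

With the real model, pass the gist's pipeline instead of the stub: `get_topic(words, pipe)`. The mask token differs between models (RoBERTa uses `<mask>`, BERT uses `[MASK]`), so passing `pipe.tokenizer.mask_token` as `mask_token` to `build_prompt` keeps the prompt valid if you swap the checkpoint.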