@dvsrepo
Last active June 1, 2021 20:11
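from transformers import pipeline
from datasets import load_dataset
import rubrix as rb

# Zero-shot classifier: SqueezeBERT fine-tuned on MNLI
model = pipeline('zero-shot-classification', model="typeform/squeezebert-mnli")
dataset = load_dataset("ag_news", split='test')

# Labels are: 'World', 'Sports', 'Business', 'Sci/Tech'
labels = dataset.features["label"].names

for example in dataset:
    # Score every candidate label for this text; the pipeline returns
    # 'labels' and 'scores' sorted from most to least likely
    prediction = model(example['text'], labels)
    record = rb.TextClassificationRecord(
        inputs=example["text"],
        # Model prediction as a list of (label, score) tuples
        prediction=list(zip(prediction['labels'], prediction['scores'])),
        # Gold label, mapped from its integer id to its name
        annotation=labels[example["label"]],
    )
    # Log the record to the "ag_news_zeroshot" dataset in Rubrix
    rb.log(record, name="ag_news_zeroshot")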
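Because each logged record pairs the zero-shot prediction with the gold AG News label, a quick offline sanity check is top-1 accuracy: how often the highest-scoring label matches the annotation. The sketch below is plain Python and does not call transformers or Rubrix. It assumes each pipeline output is a dict with `labels` and `scores` lists, as used in the script. The helper names `to_prediction_pairs` and `top1_accuracy` are hypothetical, and the sample outputs are made up for illustration, not real model scores.

```python
from typing import Dict, List, Sequence, Tuple

# Hypothetical helpers (not part of transformers or Rubrix).


def to_prediction_pairs(output: Dict[str, list]) -> List[Tuple[str, float]]:
    """Turn a zero-shot pipeline output dict into (label, score) tuples, highest score first."""
    pairs = list(zip(output["labels"], output["scores"]))
    # The pipeline already sorts by score; sorting again keeps the helper
    # correct even if it is given unsorted input.
    return sorted(pairs, key=lambda pair: pair[1], reverse=True)


def top1_accuracy(outputs: Sequence[Dict[str, list]], gold: Sequence[str]) -> float:
    """Fraction of examples whose highest-scoring label equals the gold label."""
    if len(outputs) != len(gold):
        raise ValueError("outputs and gold must have the same length")
    if not outputs:
        return 0.0  # no examples: report 0.0 rather than dividing by zero
    hits = sum(
        to_prediction_pairs(out)[0][0] == label
        for out, label in zip(outputs, gold)
    )
    return hits / len(outputs)


# Illustrative, made-up pipeline outputs (not real model scores).
outputs = [
    {"labels": ["Sports", "World", "Business", "Sci/Tech"], "scores": [0.7, 0.1, 0.1, 0.1]},
    {"labels": ["Business", "Sci/Tech", "World", "Sports"], "scores": [0.4, 0.3, 0.2, 0.1]},
]
gold = ["Sports", "Sci/Tech"]
print(top1_accuracy(outputs, gold))  # 0.5
```

To use it with the script, you could append each `prediction` dict and `labels[example["label"]]` to two lists inside the loop and pass them to `top1_accuracy` afterwards. Rubrix itself is still where you would inspect individual misclassified records.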