Skip to content

Instantly share code, notes, and snippets.

@davidefiocco
Last active July 7, 2022 17:20
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save davidefiocco/1c9e437de7b31e81bf2b8fecbe1d63ed to your computer and use it in GitHub Desktop.
Save davidefiocco/1c9e437de7b31e81bf2b8fecbe1d63ed to your computer and use it in GitHub Desktop.
Example Prodigy recipe that uses a zero-shot classifier to pre-label examples when annotating data for text classification (see https://support.prodi.gy/t/can-one-leverage-zero-shot-classifiers-for-textcat-tasks/4885)
{"text":"Spam spam lovely spam!"}
{"text":"I like scrambled eggs."}
{"text":"I prefer spam!"}
# Usage:
# python -m prodigy textcat.zero-shot -F .\textcat_zero_shot.py my_dataset dataset.jsonl facebook/bart-large-mnli --label SPAM,EGGS
from typing import Iterable, List
import prodigy
from prodigy.components.loaders import JSONL
from prodigy.components.sorters import prefer_high_scores
from prodigy.util import split_string
from tqdm import tqdm
from transformers import pipeline
class ZeroShotClassifier:
    """Pre-label a stream of examples with a zero-shot classification pipeline.

    Instances are callable: calling one with a stream of example dicts yields
    ``(score, example)`` tuples in which each example has been given the
    pipeline's first-ranked label.
    """

    def __init__(self, labels: List[str], model: str):
        """Build the pipeline and remember the candidate labels.

        Args:
            labels: Candidate labels passed to the pipeline for every example.
            model: Model identifier forwarded to ``transformers.pipeline``.
        """
        self.pipeline = pipeline("zero-shot-classification", model=model)
        self.labels = labels

    def __call__(self, stream: Iterable[dict]) -> Iterable[tuple]:
        """Classify every example in *stream* and yield it with its score.

        Each example must provide a ``"text"`` key. The example dict is
        updated in place with a ``"label"`` key and a ``"score"`` entry
        under ``"meta"``.

        Args:
            stream: Iterable of example dicts.

        Yields:
            ``(score, example)`` tuples, where ``score`` is the first entry of
            the pipeline's ``"scores"`` list for that example.
        """
        for eg in tqdm(stream):
            result = self.pipeline(eg["text"], self.labels)
            score = result["scores"][0]
            eg["label"] = result["labels"][0]
            # Merge the score into whatever metadata the example already has
            # rather than replacing it, so fields loaded from the source file
            # (ids, origins, ...) survive. The score is pre-formatted as a
            # string to have it visualized in the UI.
            meta = dict(eg.get("meta") or {})
            meta["score"] = f"{score:.3f}"
            eg["meta"] = meta
            yield (score, eg)
@prodigy.recipe(
    "textcat.zero-shot",
    dataset=("The dataset to use", "positional", None, str),
    source=("The source data as a JSONL file", "positional", None, str),
    model=("Model name (from the Huggingface hub)", "positional", None, str),
    label=("One or more comma-separated labels", "option", "l", split_string),
)
def textcat_zero_shot(dataset: str, source: str, label: List[str], model: str):
    """Annotate text classification data with zero-shot label suggestions.

    Args:
        dataset: Name of the dataset to save annotations to.
        source: Path to the JSONL file holding the examples to annotate.
        label: Candidate labels handed to the zero-shot classifier.
        model: Model name used to build the zero-shot classification pipeline.

    Returns:
        A dict of recipe components with the keys ``"view_id"``,
        ``"dataset"`` and ``"stream"``.

    Raises:
        ValueError: If no labels were supplied.
    """
    # Fail fast: without labels the classifier cannot produce a suggestion,
    # and the problem would otherwise only surface once the model has been
    # loaded and the first example is pulled from the stream.
    if not label:
        raise ValueError(
            "textcat.zero-shot needs at least one label, e.g. --label SPAM,EGGS"
        )

    # Load the stream from a JSONL file and return a generator that yields a
    # dictionary for each example in the data.
    stream = JSONL(source)

    # Load the zero-shot classification model and run it with labels. A
    # separate local name keeps the `model` argument (a string) distinct from
    # the classifier object built from it.
    classifier = ZeroShotClassifier(labels=label, model=model)
    stream = prefer_high_scores(classifier(stream))

    return {
        "view_id": "classification",  # Annotation interface to use
        "dataset": dataset,  # Name of dataset to save annotations
        "stream": stream,  # Incoming stream of examples
    }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment