-
-
Save ines/dd618b5bdc544b4ff49b363e98c6368a to your computer and use it in GitHub Desktop.
Experimental Prodigy recipes for spacy-pytorch-transformers
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Very experimental (!) Prodigy recipes for text classification annotation with | |
transformer models. Requires Prodigy (https://prodi.gy) to be installed. | |
By taking advantage of transformer models like BERT and XLNet, we can train | |
a highly accurate text classifier using only a very small set of labelled | |
examples. Transformers are also very large and slow and not always a good fit for | |
production. However, we can use them to supervise a smaller and more efficient | |
runtime model (e.g. spaCy's built-in text classifier). First, we can create | |
a small manually labelled set and use it to fine-tune the pretrained transformer | |
model. Next, we can use that model to create new training data semi- | |
automatically by running it over the data and only send the examples with scores | |
within a certain range out for annotation. Higher scores are saved to the | |
dataset automatically, lower scores are skipped. This should make it very | |
quick to create large volumes of training data. | |
Proposed workflow: | |
1. Use `pytt.textcat.manual` to create a small initial dataset for your label | |
scheme and your data. Expects all labels to be present and will use the | |
choice interface. (For very imbalanced classes, we could also consider | |
adding an option to create annotations using match patterns.) | |
2. Use `pytt.textcat.batch-train` with --use-transformer to add a text | |
classifier to a pre-trained transformer model and train it on the created | |
dataset. (Also compare it to the spaCy baseline to make sure the transformer | |
model has an advantage and is actually better: run the same recipe without | |
--use-transformer and a regular spaCy model.) | |
3. Run `textcat.pytt.create-data` with a raw text source and the pre-trained | |
transformer model. It compares the highest score against the min_score and | |
threshold, skips scores lower than the min_score, auto-saves examples with | |
scores higher than the threshold and sends everything in between out for | |
annotation. | |
4. Use `pytt.textcat.batch-train` without --use-transformer to train a spaCy | |
model on the large dataset created in the previous step. | |
To see documentation for a recipe script, run: | |
$ prodigy [recipe name] -F prodigy_textcat.py --help | |
You probably also want to set PRODIGY_LOGGING=basic for additional logging info. | |
""" | |
import prodigy | |
from prodigy.components.loaders import JSONL | |
from prodigy.components.db import connect | |
from prodigy.util import split_string, log, set_hashes | |
import spacy | |
from spacy.util import minibatch, fix_random_seed | |
import random | |
import tqdm | |
from pathlib import Path | |
import srsly | |
import torch | |
import sys | |
from spacy_pytorch_transformers.util import warmup_linear_rates | |
RANDOM_SEED = 0 | |
@prodigy.recipe(
    "pytt.textcat.manual",
    dataset=("Name of dataset to save annotations", "positional", None, str),
    source=("JSONL source data to load in", "positional", None, str),
    label_set=("Label(s) to annotate", "option", "l", split_string),
)
def pytt_textcat_manual(dataset, source, label_set):
    """
    Manually label the first batch of training examples. Those examples can
    then be used to pre-train a transformer model to then label examples
    semi-automatically.

    dataset: Name of the dataset the annotations are saved to.
    source: Path to the JSONL file to stream examples from.
    label_set: The labels to offer as single-choice options (at least two).

    Returns the recipe components: the dataset name, a stream of examples
    with the label options attached, the "choice" view ID and the config for
    the choice interface.

    Raises a ValueError if fewer than two labels are provided.
    """
    # label_set has no default, so it may arrive as None (presumably when the
    # option is omitted on the command line). Treat that the same as an empty
    # label set instead of failing on len(None).
    if not label_set or len(label_set) < 2:
        raise ValueError(
            "Currently need at least two labels. If you're doing binary "
            "classification, use a 'OTHER' or 'NOT_XXX' category."
        )

    # The options are identical for every example, so build them only once.
    options = [{"id": label, "text": label} for label in label_set]

    def get_stream(stream):
        # Attach the multiple-choice options to each incoming example.
        for eg in stream:
            eg["options"] = options
            yield eg

    stream = JSONL(source)
    stream = get_stream(stream)
    return {
        "dataset": dataset,
        "stream": stream,
        "view_id": "choice",
        "config": {"choice_style": "single", "choice_auto_accept": True},
    }
@prodigy.recipe(
    "pytt.textcat.batch-train",
    dataset=("Name of dataset to save annotations", "positional", None, str),
    spacy_model=("Pre-trained transformer model package", "positional", None, str),
    label_set=("Full label set to initialize model with", "option", "ls", split_string),
    eval_id=("Name of evaluation dataset", "option", "e", str),
    output=("Output directory", "option", "o", Path),
    drop=("Dropout rate", "option", "d", float),
    learn_rate=("Learning rate", "option", "lr", float),
    batch_size=("Batch size", "option", "b", int),
    n_iter=("Number of training epochs", "option", "i", int),
    use_transformer=("Use transformer model", "flag", "T", bool),
)
def pytt_textcat_batch_train(
    dataset,
    spacy_model,
    label_set,
    eval_id,
    output=None,
    drop=0.1,
    learn_rate=2e-5,
    batch_size=2,
    n_iter=10,
    use_transformer=False,
):
    """
    Add and train a text classifier. Expects data in Prodigy's "choice" format,
    e.g. a list of top-level "options" matching the label set and an "accept"
    list containing the selected label. All annotations should be complete and
    contain no missing values.

    dataset: Name of the dataset holding the training annotations.
    spacy_model: Name or path of the model to load and add the classifier to.
    label_set: Full list of labels the classifier is initialized with.
    eval_id: Name of the dataset holding the evaluation annotations.
    output: Optional directory. If set, the training and evaluation data are
        exported to it as JSONL and the model is saved to a "model_{i}"
        subdirectory after every epoch.
    drop: Dropout rate passed to nlp.update.
    learn_rate: Learning rate (only used with use_transformer).
    batch_size: Number of examples per training batch.
    n_iter: Number of training epochs.
    use_transformer: Add a "pytt_textcat" component instead of "textcat" and
        train with a linearly decaying learning rate.

    Raises a ValueError if no labels are given or if the training or
    evaluation dataset is missing or contains no usable examples.
    """
    # Fail early with a readable message instead of a TypeError later on.
    if not label_set:
        raise ValueError("Need a label set to initialize the text classifier with.")
    is_using_gpu = spacy.prefer_gpu()
    if is_using_gpu:
        torch.set_default_tensor_type("torch.cuda.FloatTensor")
    fix_random_seed(RANDOM_SEED)
    torch.manual_seed(RANDOM_SEED)
    db = connect()
    if output is not None and not output.exists():
        output.mkdir(parents=True)
        print(f"Created output directory '{output}'")
    nlp = spacy.load(spacy_model)
    log(f"RECIPE: Loaded model '{spacy_model}'")

    # Create the text classifier and add it as the last pipeline component.
    pipe_name = "pytt_textcat" if use_transformer else "textcat"
    config = {"exclusive_classes": True}
    if use_transformer:
        config["architecture"] = "softmax_pooler_output"
    textcat = nlp.create_pipe(pipe_name, config=config)
    log(f"RECIPE: Added '{pipe_name}' to pipeline (use_transformer: {use_transformer})")
    for label in label_set:
        textcat.add_label(label)
    nlp.add_pipe(textcat, last=True)

    # Load the annotations and convert them to (text, {"cats": ...}) tuples.
    examples = db.get_dataset(dataset)
    if examples is None:
        raise ValueError(f"Can't find training dataset '{dataset}'")
    train_data = format_data(examples, label_set)
    log(f"RECIPE: Using {len(train_data)} training examples from dataset '{dataset}'")
    dev_examples = db.get_dataset(eval_id)
    if dev_examples is None:
        raise ValueError(f"Can't find evaluation dataset '{eval_id}'")
    dev_data = format_data(dev_examples, label_set)
    log(f"RECIPE: Using {len(dev_data)} evaluation examples from dataset '{eval_id}'")
    # Without this check, empty data only surfaces later as a KeyError on the
    # losses or as an unpacking error on zip(*dev_data).
    if not train_data:
        raise ValueError(f"No usable training examples in dataset '{dataset}'")
    if not dev_data:
        raise ValueError(f"No usable evaluation examples in dataset '{eval_id}'")

    if output is not None:
        # Keep a copy of the exact data used, in both formats.
        srsly.write_jsonl(output / "training_prodigy.jsonl", examples)
        srsly.write_jsonl(output / "training_spacy.jsonl", train_data)
        srsly.write_jsonl(output / "evaluation_prodigy.jsonl", dev_examples)
        srsly.write_jsonl(output / "evaluation_spacy.jsonl", dev_data)

    if use_transformer:
        optimizer = nlp.resume_training()
        # One learning rate per update step, over the whole training run.
        learn_rates = warmup_linear_rates(
            learn_rate, 0, int(n_iter * len(train_data) / batch_size)
        )
        optimizer.pytt_use_swa = True
        optimizer.alpha = learn_rate
        disabled = None
    else:
        # Only train the text classifier (and keep the sentencizer, if present).
        other_pipes = [p for p in nlp.pipe_names if p not in ("sentencizer", "textcat")]
        disabled = nlp.disable_pipes(*other_pipes)
        optimizer = nlp.begin_training()
        learn_rates = None

    dev_texts, dev_cats = zip(*dev_data)
    print("Training the model...")
    print("{:^5}\t{:^5}\t{:^5}\t{:^5}\t{:^5}".format("#", "LOSS", "P", "R", "F"))
    for i in range(n_iter):
        losses = {}
        random.shuffle(train_data)
        batches = minibatch(train_data, size=batch_size)
        try:
            with tqdm.tqdm(desc="TRAIN", total=len(train_data), leave=False) as pbar:
                for batch in batches:
                    if use_transformer:
                        optimizer.pytt_lr = next(learn_rates)
                    texts, annots = zip(*batch)
                    nlp.update(texts, annots, sgd=optimizer, drop=drop, losses=losses)
                    pbar.update(len(texts))
            scores = evaluate(nlp, dev_texts, dev_cats)
        except KeyboardInterrupt:
            print("Stopped training.")
            sys.exit(1)
        print(
            "{0:5}\t{1:.3f}\t{2:.3f}\t{3:.3f}\t{4:.3f}".format(
                i,
                losses[pipe_name],
                scores["textcat_p"],
                scores["textcat_r"],
                scores["textcat_f"],
            )
        )
        # Re-enable the disabled components so the saved model is complete.
        # NOTE(review): once restored, those components stay enabled for the
        # remaining epochs - confirm whether they should be disabled again
        # after saving.
        if disabled is not None:
            disabled.restore()
        if output is not None:
            nlp.to_disk(output / f"model_{i}")
    if output is not None:
        print(f"Saved model data to {output}")
@prodigy.recipe(
    "textcat.pytt.create-data",
    dataset=("Name of dataset to save annotations", "positional", None, str),
    spacy_model=("Pre-trained transformer model package", "positional", None, str),
    source=("JSONL source data to load in", "positional", None, str),
    min_score=("Minimum score for example to be considered", "option", "s", float),
    threshold=("Threshold score for skipping human annotation", "option", "t", float),
)
def textcat_pytt_create_data(
    dataset, spacy_model, source, min_score=0.6, threshold=0.9
):
    """
    Create training data by using a transformer model pre-trained with
    pytt.textcat.batch-train. Ideally, the model should only need very very
    few examples to make accurate predictions. You can then use it to create
    large volumes of training data for a more efficient runtime model (e.g.
    spaCy's built-in text classifier). The incoming examples are processed
    using the transformer model and the highest-scoring label is compared
    against min_score and threshold.

    * Scores lower than the min_score are skipped.
    * Scores lower than the threshold are sent out for human annotation.
    * Scores higher than the threshold are saved to the dataset automatically.

    dataset: Name of the dataset to save annotations to (created if missing).
    spacy_model: Name or path of a model with a "pytt_textcat" component.
    source: Path to the JSONL file to stream examples from.
    min_score: Examples whose best score is below this value are dropped.
    threshold: Examples whose best score is at or above this value are
        accepted and saved without human review.

    Raises a ValueError if the model has no "pytt_textcat" component.
    """
    n_skipped = 0
    n_annotated = 0
    n_autosaved = 0
    autosave_cache = []
    db = connect()
    if dataset not in db:
        db.add_dataset(dataset)
    nlp = spacy.load(spacy_model)
    if "pytt_textcat" not in nlp.pipe_names:
        raise ValueError("Need pre-trained text classifier")

    def autosave():
        # Write all cached auto-accepted examples to the dataset.
        nonlocal autosave_cache
        if not autosave_cache:
            return
        db.add_examples(autosave_cache, datasets=[dataset])
        autosave_cache = []

    def get_stream(stream):
        nonlocal n_skipped, n_annotated, n_autosaved
        tpl = "{} <span style='font-size: 0.6em; padding-left: 0.75em'>{:.3f}</span>"
        data_tuples = ((eg["text"], eg) for eg in stream)
        for doc, eg in nlp.pipe(data_tuples, as_tuples=True):
            label, score = max(doc.cats.items(), key=lambda c: c[1])
            if score < min_score:  # score too low
                # Count before moving on, otherwise the counter never changes.
                n_skipped += 1
                if n_skipped % 10 == 0:
                    log(f"RECIPE: Skipped {n_skipped} examples")
                continue
            # Show every label with its score and pre-select the best one.
            eg["options"] = [
                {"id": cat, "html": tpl.format(cat, cat_score)}
                for cat, cat_score in doc.cats.items()
            ]
            eg["accept"] = [label]
            if score < threshold:  # score needs human feedback
                # Count before yielding so the last example sent out is
                # included even if the generator is never resumed.
                n_annotated += 1
                yield eg
            else:  # score is okay
                eg["answer"] = "accept"
                eg = set_hashes(eg)
                autosave_cache.append(eg)
                n_autosaved += 1
                # Flush only when the counter actually reaches a multiple of
                # 10, not on every later example while it stays there.
                if n_autosaved % 10 == 0:
                    log(f"RECIPE: Autosaved {n_autosaved} examples")
                    autosave()

    def on_exit(ctrl):
        # Save whatever is left in the cache and report the final counts.
        autosave()
        print(f"Skipped (score < {min_score})", n_skipped)
        print(f"Annotated (score < {threshold})", n_annotated)
        print(f"Autosaved (score >= {threshold})", n_autosaved)

    stream = JSONL(source)
    stream = get_stream(stream)
    return {
        "dataset": dataset,
        "stream": stream,
        "on_exit": on_exit,
        "view_id": "choice",
    }
def format_data(examples, label_set):
    """
    Convert accepted "choice" annotations to (text, {"cats": {...}}) tuples.

    examples: Annotated examples (dicts) with "text", "answer", "options"
        and "accept" keys.
    label_set: The labels every example's option IDs are expected to match.

    Returns a list of (text, {"cats": {label: 1.0 or 0.0}}) tuples. Examples
    that weren't accepted (or have no answer) and accepted examples without a
    selected option are left out.

    Raises a ValueError if the option IDs of an accepted example don't match
    the label set.
    """
    data = []
    for eg in examples:
        # Examples without an answer are treated like non-accepted ones.
        if eg.get("answer") != "accept":
            continue
        option_labels = [opt.get("id") for opt in eg.get("options", [])]
        try:
            labels_match = sorted(option_labels) == sorted(label_set)
        except TypeError:
            # Unsortable values (e.g. options without an ID, or no label set)
            # can never match, so report them as a mismatch as well.
            labels_match = False
        if not labels_match:
            raise ValueError(
                "Expected data with 'options' matching the label set but got:\n"
                f"Options: {option_labels}\nLabel set: {label_set}"
            )
        selected = eg.get("accept", [])
        if not selected:
            continue
        cats = {la: 1.0 if la in selected else 0.0 for la in option_labels}
        data.append((eg["text"], {"cats": cats}))
    return data
def evaluate(nlp, texts, cats):
    """
    Score the model's category predictions against the gold annotations.

    nlp: The pipeline to evaluate. Its docs are expected to expose the
        predicted scores as doc.cats.
    texts: The texts to run the pipeline over.
    cats: One annotation dict per text, with the gold scores under "cats".

    A label counts as predicted if its score is >= 0.5 and as correct if its
    gold value is >= 0.5. Predicted labels that are missing from the gold
    annotation are ignored.

    Returns a dict with the keys "textcat_p" (precision), "textcat_r" (recall)
    and "textcat_f" (F-score).
    """
    tp = 0.0  # True positives
    fp = 0.0  # False positives
    fn = 0.0  # False negatives
    with tqdm.tqdm(desc=" EVAL", total=len(texts), leave=False) as pbar:
        for doc, annot in zip(nlp.pipe(texts, batch_size=8), cats):
            gold = annot["cats"]
            for label, score in doc.cats.items():
                if label not in gold:
                    continue
                is_predicted = score >= 0.5
                is_gold = gold[label] >= 0.5
                if is_predicted and is_gold:
                    tp += 1.0
                elif is_predicted:
                    fp += 1.0
                elif is_gold:
                    fn += 1.0
                # True negatives aren't needed for precision, recall or F.
            pbar.update(1)
    # The small epsilon avoids a division by zero if nothing was predicted.
    precision = tp / (tp + fp + 1e-8)
    recall = tp / (tp + fn + 1e-8)
    if (precision + recall) == 0:
        f_score = 0.0
    else:
        f_score = 2 * (precision * recall) / (precision + recall)
    return {"textcat_p": precision, "textcat_r": recall, "textcat_f": f_score}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment