@magdaaniol
Last active May 1, 2024 15:55
A script to preprocess a Prodigy JSONL stream by adding POS spans from a spaCy pipeline
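The stream is expected to contain standard Prodigy NER tasks. A minimal input line, with purely illustrative field values, might look like this:

{"text": "Acme hired Jane", "tokens": [{"text": "Acme", "start": 0, "end": 4, "id": 0, "ws": true}, {"text": "hired", "start": 5, "end": 10, "id": 1, "ws": true}, {"text": "Jane", "start": 11, "end": 15, "id": 2, "ws": false}], "spans": [{"start": 0, "end": 4, "token_start": 0, "token_end": 0, "label": "ORG"}]}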
import copy
import json
from typing import Dict, List, Optional, Tuple
import spacy
import srsly
from spacy.language import Language
from spacy.tokens import Doc, Span
from spacy.util import filter_spans
from wasabi import msg
from prodigy.components.stream import get_stream
from prodigy.types import StreamType
from prodigy.util import set_hashes
nlp = spacy.load("en_core_web_lg")
SKIPPED = 0
POS_SPANS = ["NOUN", "PROPN", "PRON", "DET"]
def is_entity(start: int, end: int, curr_spans: List[Tuple[int, int]]) -> bool:
    for s, e in curr_spans:
        if start <= e and end >= s:
            return True
    return False
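# Illustrative example: is_entity(0, 4, [(2, 10)]) is True because the character
# ranges overlap, while is_entity(0, 4, [(5, 10)]) is False.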
def make_doc(tokens: List[Dict]) -> Doc:
    words = [token["text"] for token in tokens]
    spaces = [token.get("ws", True) for token in tokens]
    doc = Doc(nlp.vocab, words=words, spaces=spaces)
    return doc
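# Prodigy token dicts carry a "ws" flag marking whether the token is followed by
# whitespace; reusing it keeps the reconstructed Doc's text identical to the task text.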
def get_pos_spans(doc: Doc) -> List[Span]:
    pos_spans = []
    for token in doc:
        pos = token.pos_
        if pos in POS_SPANS:
            pos_spans.append(doc.char_span(token.idx, token.idx + len(token.text), pos))
    return pos_spans
def get_nps_spans(doc: Doc) -> List[Span]:
    nps_spans = []
    for np in doc.noun_chunks:
        nps_spans.append(doc.char_span(np.start_char, np.end_char, "NP"))
    return nps_spans
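# Note: doc.noun_chunks needs the dependency parse (and POS tags), which
# en_core_web_lg provides; a blank or parser-less pipeline would raise an error here.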
def convert_to_prodigy_spans(spans: List[Span]) -> List[Dict]:
    prodigy_spans = []
    for span in spans:
        prodigy_spans.append(
            {
                "token_start": span.start,
                "token_end": span.end,
                "start": span.start_char,
                "end": span.end_char,
                "text": span.text,
                "label": span.label_,
                "source": "en_core_web_lg",
            }
        )
    return prodigy_spans
def align_with_curr_tokenization(
    curr_doc: Optional[Doc], doc: Doc, spans: List[Dict]
) -> List[Dict]:
    global SKIPPED
    aligned_spans = []
    reference_doc = curr_doc if curr_doc else doc
    for s in spans:
        valid_span = reference_doc.char_span(s["start"], s["end"], s["label"])
        if valid_span:
            aligned_spans.append(s)
        else:
            msg.warn(f"Skipping {json.dumps(s, indent=4)}")
            SKIPPED += 1
    return aligned_spans
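# Doc.char_span returns None when the character offsets do not fall on token
# boundaries, so spans that cannot be mapped onto the reference tokenization are
# dropped and counted in SKIPPED.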
def add_pos(nlp: Language, stream: StreamType) -> StreamType:
    texts = ((eg["text"], eg) for eg in stream)
    for doc, eg in nlp.pipe(texts, as_tuples=True, batch_size=10):
        task = copy.deepcopy(eg)
        curr_spans = eg.get("spans", [])
        curr_tokens = eg.get("tokens")
        curr_doc = None
        if curr_tokens:
            # if tokens are available, we'll reuse them for the final task
            curr_doc = make_doc(curr_tokens)
        curr_spans_indices = [(s["start"], s["end"]) for s in curr_spans]
        # Add spans for relevant POS
        pos_spans: List[Span] = get_pos_spans(doc)
        # Add noun phrase spans
        nps_spans: List[Span] = get_nps_spans(doc)
        # Filter out POS spans that overlap with existing entity spans
        only_pos_spans = [
            span
            for span in pos_spans + nps_spans
            if not is_entity(span.start_char, span.end_char, curr_spans_indices)
        ]
        # Make sure there are no overlapping spans between POS and NP spans
        spacy_pos_spans = filter_spans(only_pos_spans)
        # Convert to Prodigy span objects
        prodigy_pos_spans = convert_to_prodigy_spans(spacy_pos_spans)
        all_prodigy_spans = curr_spans + prodigy_pos_spans
        # Make sure spans are aligned with the current tokens; skip the spans that are not
        span_obj = align_with_curr_tokenization(curr_doc, doc, all_prodigy_spans)
        task["spans"] = span_obj
        task = set_hashes(task)
        yield task
# Load dataset with NER annotations from Prodigy DB
stream = get_stream("dataset:ner-dataset")
stream = stream.apply(add_pos, nlp=nlp, stream=stream)
srsly.write_jsonl("ner_pos_dataset.jsonl", list(stream))
if SKIPPED > 0:
    msg.warn(f"Skipped {SKIPPED} spans because of misaligned tokenization")