A script to preprocess a Prodigy JSONL stream by adding POS spans from a spaCy pipeline
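For context, a minimal sketch of the kind of Prodigy task the script expects on its input stream — the text, offsets and labels below are made up for illustration:

{"text": "Acme Corp hired a new engineer.", "spans": [{"start": 0, "end": 9, "token_start": 0, "token_end": 1, "label": "ORG"}], "tokens": [{"text": "Acme", "start": 0, "end": 4, "id": 0, "ws": true}, {"text": "Corp", "start": 5, "end": 9, "id": 1, "ws": true}, ...]}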
import copy
import json
from typing import Dict, List, Optional, Tuple

import spacy
import srsly
from spacy.language import Language
from spacy.tokens import Doc, Span
from spacy.util import filter_spans
from wasabi import msg

from prodigy.components.stream import get_stream
from prodigy.types import StreamType
from prodigy.util import set_hashes

nlp = spacy.load("en_core_web_lg")
SKIPPED = 0
POS_SPANS = ["NOUN", "PROPN", "PRON", "DET"]
def is_entity(start: int, end: int, curr_spans: List[Tuple[int, int]]) -> bool:
    # True if the [start, end) character range overlaps any existing span;
    # offsets are half-open, so spans that merely touch do not overlap
    for s, e in curr_spans:
        if start < e and end > s:
            return True
    return False
def make_doc(tokens: List[Dict]) -> Doc:
    words = [token["text"] for token in tokens]
    spaces = [token.get("ws", True) for token in tokens]
    return Doc(nlp.vocab, words=words, spaces=spaces)
def get_pos_spans(doc: Doc) -> List[Span]:
    pos_spans = []
    for token in doc:
        if token.pos_ in POS_SPANS:
            pos_spans.append(
                doc.char_span(token.idx, token.idx + len(token.text), token.pos_)
            )
    return pos_spans
def get_nps_spans(doc: Doc) -> List[Span]:
    nps_spans = []
    for np in doc.noun_chunks:
        nps_spans.append(doc.char_span(np.start_char, np.end_char, "NP"))
    return nps_spans
def convert_to_prodigy_spans(spans: List[Span]) -> List[Dict]:
    prodigy_spans = []
    for span in spans:
        prodigy_spans.append(
            {
                "token_start": span.start,
                # Prodigy's token_end is inclusive; spaCy's Span.end is exclusive
                "token_end": span.end - 1,
                "start": span.start_char,
                "end": span.end_char,
                "text": span.text,
                "label": span.label_,
                "source": "en_core_web_lg",
            }
        )
    return prodigy_spans
def align_with_curr_tokenization(
    curr_doc: Optional[Doc], doc: Doc, spans: List[Dict]
) -> List[Dict]:
    # Keep only spans whose character offsets map cleanly onto token boundaries
    # in the reference tokenization (the task's own tokens, if available)
    global SKIPPED
    aligned_spans = []
    reference_doc = curr_doc if curr_doc else doc
    for s in spans:
        valid_span = reference_doc.char_span(s["start"], s["end"], s["label"])
        if valid_span:
            aligned_spans.append(s)
        else:
            msg.warn(f"Skipping {json.dumps(s, indent=4)}")
            SKIPPED += 1
    return aligned_spans
def add_pos(nlp: Language, stream: StreamType) -> StreamType:
    texts = ((eg["text"], eg) for eg in stream)
    for doc, eg in nlp.pipe(texts, as_tuples=True, batch_size=10):
        task = copy.deepcopy(eg)
        curr_spans = eg.get("spans") or []
        curr_tokens = eg.get("tokens")
        curr_doc = None
        if curr_tokens:
            # If tokens are available, we'll reuse them for the final task
            curr_doc = make_doc(curr_tokens)
        curr_spans_indices = [(s["start"], s["end"]) for s in curr_spans]
        # Add spans for relevant POS
        pos_spans: List[Span] = get_pos_spans(doc)
        # Add noun phrase spans
        nps_spans: List[Span] = get_nps_spans(doc)
        # Filter out POS spans that overlap with existing entity spans
        only_pos_spans = [
            span
            for span in pos_spans + nps_spans
            if not is_entity(span.start_char, span.end_char, curr_spans_indices)
        ]
        # Make sure there are no overlapping spans between POS and NP spans;
        # filter_spans keeps the longest span, preferring the first on ties
        spacy_pos_spans = filter_spans(only_pos_spans)
        # Convert to Prodigy span objects
        prodigy_pos_spans = convert_to_prodigy_spans(spacy_pos_spans)
        all_prodigy_spans = curr_spans + prodigy_pos_spans
        # Make sure spans are aligned with the current tokens; skip the spans that are not
        task["spans"] = align_with_curr_tokenization(curr_doc, doc, all_prodigy_spans)
        task = set_hashes(task)
        yield task
# Load dataset with NER annotations from the Prodigy DB
stream = get_stream("dataset:ner-dataset")
stream = stream.apply(add_pos, nlp=nlp, stream=stream)
srsly.write_jsonl("ner_pos_dataset.jsonl", list(stream))
if SKIPPED > 0:
    msg.warn(f"Skipped {SKIPPED} spans because of misaligned tokenization")