Created
May 16, 2018 23:37
-
-
Save jamesdunham/90765b567f1e98791f46ee6e41418cbd to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from copy import copy
import warnings

import spacy
from spacy.tokens import Doc, Span
class Template(object):
    """Create synthetic NER training data from a template document.

    Provide a template NER-annotated spacy Doc when instantiating the class. Passing text to the `render` method
    populates the templated entity spans, preserving entity labels, and generates a new Doc.

    Attributes
    ----------
    template : Doc
        NER-annotated template document.
    ents : list
        Entities of the template document, replaced when rendered. A list of 4-tuples like
        `[(ent.text, ent.start_char, ent.end_char, ent.label_)]`.

    Methods
    -------
    update(template)
        Replace the existing template.
    render(substitutes)
        Populate the template by replacing its entity spans with `substitutes`.
    __len__()
        Count of template entities.
    __str__()
        Template text.

    Notes
    -----
    Any attributes of the template document and its component elements (e.g., tokens) are discarded, aside from IOB
    tags and entity labels.
    """

    def __init__(self, template: Doc):
        # Rendered text is re-tokenized with a blank English pipeline (tokenizer only).
        self._blank_nlp = spacy.blank('en')
        self.update(template)

    def __str__(self):
        return self.template.text

    def __len__(self):
        return len(self.ents)

    def update(self, template: Doc):
        """Replace the current template with `template`.

        Raises
        ------
        ValueError
            If `template` has no entities. The existing template is left unchanged in that case.
        """
        # Validate before mutating so a rejected template cannot leave `template` and `ents` out of sync.
        # A real exception (not `assert`) keeps the check active under `python -O`.
        if not template.ents:
            raise ValueError('Template document must contain at least one entity')
        self.template = template
        self.ents = self._extract_ents(template)

    def render(self, substitutes: list) -> Doc:
        """Populate the template, replacing each entity span with the next substitute.

        Parameters
        ----------
        substitutes : sequence of str
            Replacement texts, consumed in order, one per template entity. Extra items are ignored and the
            sequence itself is not modified.

        Returns
        -------
        Doc
            A new Doc produced by a blank English pipeline, with an entity (carrying the template entity's label)
            over each substituted span that aligns with token boundaries.

        Raises
        ------
        ValueError
            If there are fewer substitutes than template entities.
        """
        if len(self.ents) > len(substitutes):
            raise ValueError('Need at least as many substitute entities as original entities')
        # An iterator consumes substitutes in order without copying or mutating the caller's sequence,
        # and without the O(n) cost of list.pop(0).
        remaining = iter(substitutes)
        pieces = []
        spans = []
        offset = 0
        n_tokens = len(self.template)
        for i, token in enumerate(self.template):
            iob = token.ent_iob_
            if iob in ('', 'O'):
                # Pass non-entity tokens through unchanged.
                pieces.append(token.text_with_ws)
                offset += len(token.text_with_ws)
                continue
            if iob == 'B':
                # The whole entity span is replaced by its substitute, emitted at the span's first token.
                replacement = next(remaining)
                spans.append((offset, offset + len(replacement), token.ent_type_))
                pieces.append(replacement)
                offset += len(replacement)
            # Only the whitespace after an entity's last token is kept. The document's final token always ends
            # its span; checking it explicitly avoids comparing that token against itself.
            if i + 1 == n_tokens or self.template[i + 1].ent_iob_ != 'I':
                pieces.append(token.whitespace_)
                offset += len(token.whitespace_)
        output_doc = self._blank_nlp(''.join(pieces))
        self._add_ents(output_doc, spans)
        return output_doc

    @staticmethod
    def _add_ents(doc, spans):
        """Attach `(start_char, end_char, label)` spans to `doc` as entities, in place, and return `doc`.

        Spans for which `doc.char_span` yields nothing (e.g. offsets that do not fall on token boundaries of
        `doc`) are skipped with a warning.
        """
        ents = list(doc.ents)
        for start, end, label in spans:
            span = doc.char_span(start, end, label=label)
            if not span:
                warnings.warn('Skipping invalid span ({}, {}, {!r}): offsets do not align with tokens'.format(
                    start, end, label))
                continue
            ents.append(span)
        # Assign once instead of rebuilding and reassigning doc.ents for every span.
        doc.ents = ents
        return doc

    @staticmethod
    def _extract_ents(doc):
        """Return `doc`'s entities as `(text, start_char, end_char, label_)` tuples."""
        return [(e.text, e.start_char, e.end_char, e.label_) for e in doc.ents]
def test_one_ent():
    """Render a one-entity template and check the text, entity span, and preserved label."""
    # Annotate the template by hand on a blank pipeline so the test needs no downloaded statistical model
    # and does not depend on what such a model happens to predict.
    nlp = spacy.blank('en')
    text = u'Cambridge is full of rabbits.'
    doc = nlp(text)
    doc.ents = [doc.char_span(0, len('Cambridge'), label='GPE')]
    template = Template(doc)
    substitutes = ['New York']
    output_doc = template.render(substitutes)
    assert isinstance(output_doc, Doc)
    assert output_doc.text == 'New York is full of rabbits.'
    assert output_doc.ents and isinstance(output_doc.ents[0], Span)
    assert output_doc.ents[0].text == 'New York'
    assert output_doc.ents[0].label_ == 'GPE'
def test_two_ents():
    """Render a two-entity template and check both substitutions, in order, with labels preserved."""
    # Deterministic hand annotation on a blank pipeline; no statistical model required.
    nlp = spacy.blank('en')
    text = u'New York or San Francisco?'
    doc = nlp(text)
    doc.ents = [
        doc.char_span(text.index(name), text.index(name) + len(name), label='GPE')
        for name in ('New York', 'San Francisco')
    ]
    template = Template(doc)
    substitutes = ['Boston', 'Philadelphia']
    output_doc = template.render(substitutes)
    assert isinstance(output_doc, Doc)
    assert output_doc.text == 'Boston or Philadelphia?'
    assert len(output_doc.ents) == 2
    assert output_doc.ents[0].text == 'Boston'
    assert output_doc.ents[1].text == 'Philadelphia'
    assert [ent.label_ for ent in output_doc.ents] == ['GPE', 'GPE']
if __name__ == '__main__':
    # Run the smoke tests in declaration order when executed as a script.
    for run_test in (test_one_ent, test_two_ents):
        run_test()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment