Skip to content

Instantly share code, notes, and snippets.

@jamesdunham
Created May 16, 2018 23:37
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jamesdunham/90765b567f1e98791f46ee6e41418cbd to your computer and use it in GitHub Desktop.
Save jamesdunham/90765b567f1e98791f46ee6e41418cbd to your computer and use it in GitHub Desktop.
from copy import copy
import spacy
from spacy.tokens import Doc, Span
class Template(object):
    """Create synthetic NER training data from a template document.

    Provide a template NER-annotated spacy Doc when instantiating the class. Passing text to the `render` method
    populates the templated entity spans, preserving entity labels, and generates a new Doc.

    Attributes
    ----------
    template : Doc
        NER-annotated template document.
    ents : list
        Entities of the template document, replaced when rendered. A list of 4-tuples like
        `[(ent.text, ent.start_char, ent.end_char, ent.label_)]`.

    Methods
    -------
    update(template)
        Replace the existing template.
    render(substitutes)
        Populate the template by replacing its entity spans with `substitutes`.
    __len__()
        Count of template entities.
    __str__()
        Template text.

    Notes
    -----
    Any attributes of the template document and its component elements (e.g., tokens) are discarded, aside from IOB
    tags and entity labels.
    """

    def __init__(self, template: Doc):
        # Blank pipeline used by `render` to turn the rendered text into a new Doc.
        self._blank_nlp = spacy.blank('en')
        self.update(template)

    def __str__(self):
        """Return the text of the template document."""
        return self.template.text

    def __len__(self):
        """Return the number of entities in the template."""
        return len(self.ents)

    def update(self, template: Doc):
        """Replace the existing template.

        Parameters
        ----------
        template : Doc
            NER-annotated document; must contain at least one entity.

        Raises
        ------
        AssertionError
            If `template` has no entities. The instance is left unchanged in that case.
        """
        # Validate before touching any state so a rejected template cannot leave `template` and `ents` out of sync.
        # The error is raised explicitly (not via an `assert` statement) so the check is not stripped under
        # `python -O`; AssertionError is kept because that is what callers have always seen here.
        if not template.ents:
            raise AssertionError('Template document must contain at least one entity')
        self.template = template
        self.ents = self._extract_ents(template)

    def render(self, substitutes: list) -> Doc:
        """Populate the template by replacing its entity spans with `substitutes`.

        Parameters
        ----------
        substitutes : list
            Replacement strings, consumed in order, one per template entity. Extra items are ignored. The
            sequence is only indexed, never modified.

        Returns
        -------
        Doc
            Document built from the rendered text by the blank pipeline, with the template's entity labels applied
            to the substituted spans.

        Raises
        ------
        ValueError
            If there are fewer substitutes than template entities.
        """
        if len(self.ents) > len(substitutes):
            raise ValueError('Need at least as many substitute entities as original entities')

        pieces = []       # fragments of the output text, joined once at the end
        spans = []        # (start_char, end_char, label) of each substituted entity in the output text
        offset = 0        # character length of the output text built so far
        next_substitute = 0
        last_idx = len(self.template) - 1

        for token_idx, token in enumerate(self.template):
            iob = token.ent_iob_
            if iob == 'B':
                # The first token of an entity span stands in for the whole span: emit its substitute here and
                # emit nothing for the span's remaining ('I') tokens.
                substitute = substitutes[next_substitute]
                next_substitute += 1
                spans.append((offset, offset + len(substitute), token.ent_type_))
                pieces.append(substitute)
                offset += len(substitute)
            if iob in ('B', 'I'):
                # Only the last token of a span contributes its trailing whitespace. The final token of the
                # document has no successor and therefore always closes its span; checking it against itself
                # would drop the whitespace of a multi-token entity that ends the document.
                closes_span = token_idx == last_idx or self.template[token_idx + 1].ent_iob_ != 'I'
                if closes_span:
                    pieces.append(token.whitespace_)
                    offset += len(token.whitespace_)
            elif iob in ('', 'O'):
                # Pass non-entity tokens through unchanged.
                pieces.append(token.text_with_ws)
                offset += len(token.text_with_ws)

        output_doc = self._blank_nlp(''.join(pieces))
        self._add_ents(output_doc, spans)
        return output_doc

    @staticmethod
    def _add_ents(doc, spans):
        """Append each `(start_char, end_char, label)` in `spans` to `doc.ents` and return `doc`.

        Spans for which `doc.char_span` yields a falsy result are reported on stdout and skipped.
        """
        for start, end, label in spans:
            span = doc.char_span(start, end, label=label)
            if span:
                doc.ents = list(doc.ents) + [span]
            else:
                # NOTE(review): per the spaCy docs char_span returns None when the offsets do not line up with
                # token boundaries; such substitutes end up in the text but without an entity annotation.
                print('Skipping invalid span!')
        return doc

    @staticmethod
    def _extract_ents(doc):
        """Return `(text, start_char, end_char, label_)` for every entity in `doc`."""
        return [(e.text, e.start_char, e.end_char, e.label_) for e in doc.ents]
def test_one_ent():
    """Render a single-entity template and check the text and the entity of the result."""
    pipeline = spacy.load('en_core_web_sm')
    source_doc = pipeline(u'Cambridge is full of rabbits.')
    rendered = Template(source_doc).render(['New York'])

    assert isinstance(rendered, Doc)
    assert rendered.text == 'New York is full of rabbits.'

    # The substituted span must come back as a real entity Span carrying the new text.
    entities = rendered.ents
    assert entities and isinstance(entities[0], Span)
    assert entities[0].text == 'New York'
def test_two_ents():
    """Render a two-entity template and check both entities are replaced in order."""
    pipeline = spacy.load('en_core_web_sm')
    source_doc = pipeline(u'New York or San Francisco?')
    rendered = Template(source_doc).render(['Boston', 'Philadelphia'])

    assert isinstance(rendered, Doc)
    assert rendered.text == 'Boston or Philadelphia?'
    assert len(rendered.ents) == 2

    # Entities must appear in the same order as the substitutes that produced them.
    for position, expected_text in enumerate(('Boston', 'Philadelphia')):
        assert rendered.ents[position].text == expected_text
if __name__ == '__main__':
    # Run the smoke tests in order when the module is executed as a script.
    for smoke_test in (test_one_ent, test_two_ents):
        smoke_test()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment