Last active
June 13, 2020 15:33
-
-
Save lvngd/1b8f0405591ee47f9ddd8131f27a2db4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import unicode_literals, print_function | |
import random | |
import plac | |
import pickle | |
from pathlib import Path | |
import spacy | |
from spacy.util import minibatch, compounding | |
""" | |
Script to train a custom Named Entity Recognizer with Spacy. | |
https://spacy.io/usage/training | |
""" | |
@plac.annotations(
    model=("Model name. Defaults to blank 'en' model.", "option", "m", str),
    train_data=("File path to training data. Defaults to training_data.pickle in same directory.", "option", "d", str),
    output_dir=("Optional output directory", "option", "o", Path),
    number_iterations=("Number of training iterations. Defaults to 100", "option", "n", int),
)
def train_model(model=None, train_data='training_data.pickle', output_dir=None, number_iterations=100):
    """Train a custom spaCy (v2.x) Named Entity Recognizer.

    Follows the workflow described at https://spacy.io/usage/training.

    Args:
        model: Name/path of an existing spaCy model to fine-tune, or None
            to start from a blank 'en' pipeline.
        train_data: Path to a pickle of
            [(text, {"entities": [(start, end, label), ...]}), ...].
        output_dir: Directory to save the trained model to; created
            (including parents) if missing. If None, nothing is saved.
        number_iterations: Number of full passes over the training data.
    """
    # SECURITY: pickle.load executes arbitrary code embedded in the file —
    # only load training data produced by a trusted source.
    with open(train_data, 'rb') as data:
        TRAIN_DATA = pickle.load(data)
    if model is not None:
        nlp = spacy.load(model)
    else:
        nlp = spacy.blank("en")
        nlp.vocab.vectors.name = 'spacy_pretrained_vectors'
    if "ner" not in nlp.pipe_names:
        # a blank model has no pipeline components; create and append NER
        ner = nlp.create_pipe('ner')
        nlp.add_pipe(ner, last=True)
    else:
        # reuse the existing NER component so new labels can be added to it
        ner = nlp.get_pipe("ner")
    # register every entity label that appears in the training data
    for _, annotations in TRAIN_DATA:
        for ent in annotations.get("entities"):
            ner.add_label(ent[2])
    # disable all other pipeline components so only NER weights are updated
    other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "ner"]
    with nlp.disable_pipes(*other_pipes):
        # FIX: capture the optimizer and pass it to every nlp.update call.
        # The original discarded begin_training()'s return value and never
        # initialized an optimizer when fine-tuning an existing model; the
        # spaCy v2 examples use resume_training() for that path.
        if model is None:
            optimizer = nlp.begin_training()
        else:
            optimizer = nlp.resume_training()
        for itn in range(number_iterations):
            random.shuffle(TRAIN_DATA)
            losses = {}
            # batch with an increasing batch-size schedule (spaCy-recommended)
            batches = minibatch(TRAIN_DATA, size=compounding(4.0, 32.0, 1.001))
            for batch in batches:
                texts, annotations = zip(*batch)
                nlp.update(
                    texts,
                    annotations,
                    sgd=optimizer,
                    drop=0.5,  # dropout makes memorizing the data harder
                    losses=losses,
                )
            print("losses", losses)
    if output_dir is not None:
        output_dir = Path(output_dir)
        # FIX: parents=True creates missing intermediate directories;
        # exist_ok avoids a race with concurrent creation.
        output_dir.mkdir(parents=True, exist_ok=True)
        nlp.to_disk(output_dir)
        print("Saved model to", output_dir)
        # sanity check: reload the saved model and show its predictions
        print("Testing. Loading from", output_dir)
        nlp2 = spacy.load(output_dir)
        for text, _ in TRAIN_DATA:
            doc = nlp2(text)
            print("Entities", [(ent.text, ent.label_) for ent in doc.ents])
            print("Tokens", [(t.text, t.ent_type_, t.ent_iob) for t in doc])
if __name__ == "__main__":
    # plac builds the command-line interface from train_model's annotations
    plac.call(train_model)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment