import spacy
from spacy.util import minibatch, compounding

nlp = spacy.blank('en')  # create blank Language class
print("Created blank 'en' model")

# add the text classifier to the pipeline if it doesn't exist;
# nlp.create_pipe works for built-ins that are registered with spaCy
if 'textcat' not in nlp.pipe_names:
    textcat = nlp.create_pipe('textcat')
    nlp.add_pipe(textcat, last=True)
# otherwise, get it, so we can add labels to it
else:
    textcat = nlp.get_pipe('textcat')

# add each label to the text classifier (0 if rating <= 3, else 1)
for label in self.label_cols:
    textcat.add_label(label)

print("Loading data...")
limit = None  # set to an int to cap the number of examples
(train_texts, train_cats), (dev_texts, dev_cats) = self.load_train(limit=limit, split=0.95)
print("Using {} examples ({} training, {} evaluation)"
      .format(len(train_texts) + len(dev_texts), len(train_texts), len(dev_texts)))

train_texts = [str(doc) for doc in train_texts]
dev_texts = [str(doc) for doc in dev_texts]
train_data = list(zip(train_texts,
                      [{'cats': cats} for cats in train_cats]))

n_iter = 10
# disable every other pipeline component so only textcat gets trained
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != 'textcat']
with nlp.disable_pipes(*other_pipes):
    optimizer = nlp.begin_training()
    print("Training the model...")
    print('{:^5}\t{:^5}\t{:^5}\t{:^5}'.format('LOSS', 'P', 'R', 'F'))
    for i in range(n_iter):
        losses = {}
        # batch up the examples using spaCy's minibatch, growing the
        # batch size from 4 to 32 by a factor of 1.001 per batch
        batches = minibatch(train_data, size=compounding(4., 32., 1.001))
        for batch in batches:
            texts, annotations = zip(*batch)
            nlp.update(texts, annotations, sgd=optimizer, drop=0.2,
                       losses=losses)
        with textcat.model.use_params(optimizer.averages):
            # evaluate on the dev data split off in load_train()
            scores_eval = self.evaluate(nlp.tokenizer, textcat, dev_texts, dev_cats)
            scores_train = self.evaluate(nlp.tokenizer, textcat, train_texts, train_cats)
        # should log these somewhere
        print('Train {0:.3f}\t{1:.3f}\t{2:.3f}\t{3:.3f}'  # print a simple table
              .format(losses['textcat'], scores_train['textcat_p'],
                      scores_train['textcat_r'], scores_train['textcat_f']))
        print('Eval  {0:.3f}\t{1:.3f}\t{2:.3f}\t{3:.3f}'  # print a simple table
              .format(losses['textcat'], scores_eval['textcat_p'],
                      scores_eval['textcat_r'], scores_eval['textcat_f']))

# dump the trained pipeline to disk so we can load it later
nlp.to_disk('appclass')
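The snippet calls three members that aren't shown here: self.label_cols (the list of category names), self.load_train (which returns the train/dev split), and self.evaluate. A minimal sketch of evaluate, modeled on the helper from spaCy's official textcat training example, could look like the following; the 0.5 decision threshold and the small epsilon to avoid division by zero are assumptions carried over from that example:

def evaluate(self, tokenizer, textcat, texts, cats):
    # run only the tokenizer, then score the docs with the textcat model
    docs = (tokenizer(text) for text in texts)
    tp = 0.0   # true positives
    fp = 1e-8  # false positives (epsilon avoids division by zero)
    fn = 1e-8  # false negatives
    tn = 0.0   # true negatives
    for i, doc in enumerate(textcat.pipe(docs)):
        gold = cats[i]
        for label, score in doc.cats.items():
            if label not in gold:
                continue
            if score >= 0.5 and gold[label] >= 0.5:
                tp += 1.0
            elif score >= 0.5 and gold[label] < 0.5:
                fp += 1.0
            elif score < 0.5 and gold[label] < 0.5:
                tn += 1.0
            elif score < 0.5 and gold[label] >= 0.5:
                fn += 1.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f_score = 2 * (precision * recall) / (precision + recall)
    return {'textcat_p': precision, 'textcat_r': recall, 'textcat_f': f_score}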
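Once training has finished, the directory written by nlp.to_disk can be loaded back with spacy.load and applied to new text; each doc then carries one score per label in doc.cats. The example text below is made up, and the exact label names depend on self.label_cols:

import spacy

nlp = spacy.load('appclass')  # directory written by nlp.to_disk above
doc = nlp("This app crashes every time I open it.")
print(doc.cats)  # maps each label from self.label_cols to a score in [0, 1]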