Last active
July 17, 2020 16:39
-
-
Save mariastefan/e664b279e735916d8e196467769874b5 to your computer and use it in GitHub Desktop.
train new tagger after custom tokenizer
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random | |
from pathlib import Path | |
import spacy | |
import sys | |
import os | |
sys.path.append('.') | |
from resolution_coreferences_pronominales.custom_model_training.custom_tokenizer import nlp_loader | |
output_dir = os.path.abspath(os.path.dirname(__file__)) + '/customPOS/' | |
# base_model = 'fr_core_news_sm' | |
base_model = 'customTokenizerModel/' | |
# meta.json of the new model | |
lang = 'fr' | |
name = 'custom_sm' | |
description = 'Custom model based on fr_core_news_sm : French multi-task CNN trained on the ' \ | |
'French Sequoia (Universal Dependencies) and ' \ | |
'WikiNER corpus. Assigns context-specific token vectors, POS tags, dependency ' \ | |
'parse and named entities. Supports identification of PER, LOC, ORG and MISC ' \ | |
'entities.' | |
version = '0.0.0' | |
TRAIN_DATA = [ | |
('Adrien voudrait plus de gateau. Il est culotté celui-là.', | |
{'tags': ['PROPN', 'VERB', 'ADV', 'ADP', 'NOUN', 'PUNCT', 'PRON', 'VERB', 'ADJ', 'PRON','PUNCT']}) | |
] | |
def train_tagger(model='fr_core_news_sm', output=None, n_iter=25): | |
# Loading the model with custom tokenizer | |
nlp = nlp_loader() | |
# nlp = spacy.load(base_model) | |
# Training the custom model tagger starting with the existing 'fr_core_news_sm' tagger | |
nlp.vocab.vectors.name = 'spacy_pretrained_vectors' | |
optimizer = nlp.begin_training() | |
for i in range(n_iter): | |
random.shuffle(TRAIN_DATA) | |
losses = {} | |
for text, annotations in TRAIN_DATA: | |
nlp.update([text], [annotations], sgd=optimizer, losses=losses) | |
print(losses) | |
# Temporary ! Testing the trained model with phrases from a file | |
# ------------------------------------------------------------------------------------------ # | |
print('\nTesting the trained model :') | |
text1 = 'Il est culotté celui-là.' | |
text2 = 'Il est culotté celui-ci.' | |
doc1 = nlp(text1) | |
doc2 = nlp(text2) | |
print('"' + doc1.text + '"') | |
print([[token.text, token.pos_] for token in doc1]) | |
print('"' + doc2.text + '"') | |
print([[token.text, token.pos_] for token in doc2]) | |
# ------------------------------------------------------------------------------------------ # | |
# save model to output directory | |
if output is not None: | |
output = Path(output) | |
if not output.exists(): | |
output.mkdir() | |
nlp.meta['lang'] = lang | |
nlp.meta['name'] = name | |
nlp.meta['description'] = description | |
nlp.meta['version'] = version | |
nlp.to_disk(output) | |
print('\nSaved model to', output) | |
# test the saved model | |
# Temporary ! Testing the trained model with phrases from a file | |
# ------------------------------------------------------------------------------------------ # | |
print('Loading from', output, '\n') | |
print('Testing the saved trained model :') | |
nlp2 = spacy.load(output) | |
text1 = 'Il est culotté celui-là.' | |
text2 = 'Il est culotté celui-ci.' | |
doc1 = nlp2(text1) | |
doc2 = nlp2(text2) | |
print('"' + doc1.text + '"') | |
print([[token.text, token.pos_] for token in doc1]) | |
print('"' + doc2.text + '"') | |
print([[token.text, token.pos_] for token in doc2]) | |
# ------------------------------------------------------------------------------------------ # | |
if __name__ == '__main__': | |
train_tagger(base_model, output_dir) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import fr_core_news_sm | |
import os | |
from spacy.matcher import Matcher | |
import json | |
from spacy.language import Language | |
from spacy.tokens import Doc | |
import spacy | |
json_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) + \ | |
'/custom_model_training/custom_model_params/compound_words.json' | |
output_dir = os.path.abspath(os.path.dirname(__file__)) + '/customTokenizerModel/' | |
lang = 'fr' | |
name = 'custom_tokenizer_sm' | |
description = 'Custom model tokenizer model based on fr_core_news_sm : French multi-task CNN trained on the ' \ | |
'French Sequoia (Universal Dependencies) and ' \ | |
'WikiNER corpus. Assigns context-specific token vectors, POS tags, dependency ' \ | |
'parse and named entities. Supports identification of PER, LOC, ORG and MISC ' \ | |
'entities.' | |
version = '0.0.0' | |
def nlp_loader(): | |
""" | |
Temporary fonction allowing to load the nlp with the custom tokenizer. | |
This will later become a fonction creating a new model and scripts will no longer be loading the model with | |
this fonction but directly from the new customized model | |
:return: nlp | |
""" | |
nlp = fr_core_news_sm.load() | |
class CompoundWordsMerger: | |
def __init__(self, words_path): | |
# self.model_size = model_size | |
self.words_path = words_path | |
def __call__(self, doc: Doc): | |
# Adding hyphen compound words to the matcher | |
matcher = Matcher(nlp.vocab) | |
matcher.add('HYPHENS', None, [{'IS_ALPHA': True}, {'TEXT': '-'}, {'IS_ALPHA': True}]) | |
# Opening the json file containing the information about our custom compound words | |
with open(self.words_path) as json_file: | |
compound_words = json.load(json_file) | |
# Creating a list which will contain the keys of the dictionary in words_path json file | |
# These keys correspond to the custom compound words text | |
custom_exceptions_list = [] | |
for key in compound_words.keys(): | |
custom_exceptions_list.append(key) | |
# Adding the custom compound words from the json file to the matcher | |
for word in custom_exceptions_list: | |
pattern = [] | |
for word in word.split(' '): | |
pattern.append({'TEXT': word}) | |
matcher.add(word, None, pattern) | |
# Adding the matches containing the compound words to the doc | |
matched_spans = [] | |
matches = matcher(doc) | |
for match_id, start, end in matches: | |
span = doc[start:end] | |
matched_spans.append(span) | |
if str(span) in compound_words.keys(): | |
nlp.tokenizer.add_special_case(str(span), | |
[{'ORTH': str(span), 'POS': compound_words[str(span)]["pos"]}]) | |
for span in matched_spans: # merge into one token after collecting all matches | |
span.merge() | |
# Adding the custom lemmas for the custom compound words | |
for token in doc: | |
if ' ' in token.text: | |
if token.text in compound_words.keys(): | |
token.lemma_ = compound_words[token.text]["lemma"] | |
return doc | |
nlp.add_pipe(CompoundWordsMerger(json_path), | |
first=True) # , first=True : add it right after the tokenizer; default : last | |
# Adding the custom pipeline to the factories | |
Language.factories['CompoundWordsMerger'] = lambda _: CompoundWordsMerger(json_path) | |
nlp.meta['lang'] = lang | |
nlp.meta['name'] = name | |
nlp.meta['description'] = description | |
nlp.meta['version'] = version | |
nlp.to_disk(output_dir) | |
print("Saved model to", output_dir) | |
# ------------------------- temporary test ------------------------------------------------ | |
if __name__ == '__main__': | |
nlp2 = spacy.load(output_dir) | |
text = 'Il est beau celui-là. Les intelligences artificielles sont méchantes.' | |
doc = nlp2(text) | |
for token in doc: | |
print(token) | |
return nlp | |
if __name__ == '__main__': | |
nlp_loader() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# This is the output when I tag 'celui-là' with 'PRON','PUNCT','PRON' (rather than what I want to achieve : just 'PRON') | |
Saved model to /home/maria/Documents/resolution-des-coreferences-pronominales/resolution_coreferences_pronominales/custom_model_training/customTokenizerModel/ | |
{'parser': 0.0, 'tagger': 0.0, 'ner': 0.0} | |
{'ner': 0.0, 'parser': 0.0, 'tagger': 0.0} | |
{'parser': 0.0, 'ner': 0.0, 'tagger': 0.0} | |
{'tagger': 0.0, 'parser': 0.0, 'ner': 0.0} | |
{'tagger': 0.0, 'parser': 0.0, 'ner': 0.0} | |
{'tagger': 0.0, 'parser': 0.0, 'ner': 0.0} | |
{'parser': 0.0, 'ner': 0.0, 'tagger': 0.0} | |
{'parser': 0.0, 'tagger': 0.0, 'ner': 0.0} | |
{'tagger': 0.0, 'ner': 0.0, 'parser': 0.0} | |
{'tagger': 0.0, 'ner': 0.0, 'parser': 0.0} | |
{'ner': 0.0, 'parser': 0.0, 'tagger': 0.0} | |
{'tagger': 0.0, 'parser': 0.0, 'ner': 0.0} | |
{'parser': 0.0, 'tagger': 0.0, 'ner': 0.0} | |
{'ner': 0.0, 'parser': 0.0, 'tagger': 0.0} | |
{'tagger': 0.0, 'ner': 0.0, 'parser': 0.0} | |
{'parser': 0.0, 'ner': 0.0, 'tagger': 0.0} | |
{'ner': 0.0, 'parser': 0.0, 'tagger': 0.0} | |
{'parser': 0.0, 'ner': 0.0, 'tagger': 0.0} | |
{'tagger': 0.0, 'ner': 0.0, 'parser': 0.0} | |
{'tagger': 0.0, 'parser': 0.0, 'ner': 0.0} | |
{'ner': 0.0, 'tagger': 0.0, 'parser': 0.0} | |
{'parser': 0.0, 'ner': 0.0, 'tagger': 0.0} | |
{'parser': 0.0, 'tagger': 0.0, 'ner': 0.0} | |
{'ner': 0.0, 'parser': 0.0, 'tagger': 0.0} | |
{'ner': 0.0, 'tagger': 0.0, 'parser': 0.0} | |
Testing the trained model : | |
"Il est culotté celui-là." | |
[['Il', 'PRON'], ['est', 'AUX'], ['culotté', 'VERB'], ['celui-là', 'ADJ'], ['.', 'PUNCT']] | |
"Il est culotté celui-ci." | |
[['Il', 'PRON'], ['est', 'AUX'], ['culotté', 'VERB'], ['celui-ci', 'PRON'], ['.', 'PUNCT']] | |
Saved model to /home/maria/Documents/resolution-des-coreferences-pronominales/resolution_coreferences_pronominales/custom_model_training/customPOS | |
Loading from /home/maria/Documents/resolution-des-coreferences-pronominales/resolution_coreferences_pronominales/custom_model_training/customPOS | |
Testing the saved trained model : | |
"Il est culotté celui-là." | |
[['Il', 'PRON'], ['est', 'AUX'], ['culotté', 'VERB'], ['celui-là', 'ADJ'], ['.', 'PUNCT']] | |
"Il est culotté celui-ci." | |
[['Il', 'PRON'], ['est', 'AUX'], ['culotté', 'VERB'], ['celui-ci', 'PRON'], ['.', 'PUNCT']] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Output when trying to tag 'celui-là' with 'PRON' | |
Traceback (most recent call last): | |
File "........./train_new_tagger.py", line 89, in <module> | |
train_tagger(base_model, output_dir) | |
File "........./train_new_tagger.py", line 43, in train_tagger | |
nlp.update([text], [annotations], sgd=optimizer, losses=losses) | |
File "........./spacy/language.py", line 496, in update | |
docs, golds = self._format_docs_and_golds(docs, golds) | |
File "........./spacy/language.py", line 468, in _format_docs_and_golds | |
gold = GoldParse(doc, **gold) | |
File "gold.pyx", line 801, in spacy.gold.GoldParse.__init__ | |
IndexError: list index out of range |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment