Generate Rasa NLU training data for custom entities
'''
Utility to create training data samples for custom NER with Rasa NLU, given a set of sentences
containing a placeholder ($entity). The placeholder is filled with the provided values (dropins).
Tested using ner_crf together with ner_regex.
See: https://nlu.rasa.com/entities.html
In this sample a custom entity "security_term" is created for the intent "search_confluence".
'''
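# Each sample produced below is one JSON object in Rasa NLU's JSON training-data format.
# A minimal sketch of the enclosing file (an assumption for illustration; this script does
# not produce it) into which the printed, comma-terminated objects can be pasted:
#
#   {"rasa_nlu_data": {"common_examples": [ ...samples printed by this script... ]}}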
import random
import json
from typing import Iterable, Set
from string import Template
'''
For every value the entity can take (called a dropin), pick random sentences and build training
data JSON samples from them. The sentences aren't plain strings! They are Templates,
see https://docs.python.org/2.4/lib/node109.html
Within a template, $entity is replaced with the actual dropin,
e.g. "Search confluence for $entity" -> "Search confluence for SQL Injection"
'''
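# Quick illustration of the Template mechanics (plain standard library, nothing Rasa-specific):
#   >>> Template("Search confluence for $entity").substitute(entity="SQL Injection")
#   'Search confluence for SQL Injection'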
def create_custom_ner_entity(intent_name: str, entity_name: str, sentences: Set[Template],
                             sentence_count: int, dropins: Iterable[str]):
    ner_json_samples = set()
    for dropin in dropins:
        # random.sample() needs a sequence, so materialize the set of templates as a list
        for sentence in random.sample(list(sentences), sentence_count):
            ner_json_samples.add(create_json_from(sentence, dropin, intent_name, entity_name))
    return ner_json_samples
def create_json_from(sentence, dropin, intent_name, entity_name):
    # Substitute a marker first to find where the dropin will start in the final text
    start_index = sentence.substitute(entity="---findme!---").find("---findme!---")
    end_index = start_index + len(dropin)
    return json.dumps({
        "text": sentence.substitute(entity=dropin),
        "intent": intent_name,
        "entities": [
            {
                "start": start_index,
                "end": end_index,
                "value": dropin,
                "entity": entity_name
            }
        ]
    })
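# Worked example of the offset computation: for the template "Search confluence for $entity"
# and the dropin "SQL Injection", the marker lands at index 22 and the dropin is 13 characters
# long, so create_json_from yields (pretty-printed here for readability):
#   {"text": "Search confluence for SQL Injection", "intent": "search_confluence",
#    "entities": [{"start": 22, "end": 35, "value": "SQL Injection", "entity": "security_term"}]}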
def print_inner_json_list(json_set: Set[str]):
    # Trailing comma so the printed objects can be pasted straight into a JSON array
    for json_sample in json_set:  # renamed so the loop variable doesn't shadow the json module
        print(json_sample + ",")
'''
The referenced files contain one term/sentence per line
'''
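# Assumed layout of the two input files, one entry per line. The first entry of each is taken
# from the sample-line comments below; the second lines are hypothetical additions:
#   security_terms:    SQL Injection
#                      Cross-Site Scripting
#   search_sentences:  Search confluence for $entity
#                      Find confluence articles about $entity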
def train_every_dropin_with_three_random_sentences():
    with open('security_terms') as terms_file:
        dropins = [line.rstrip('\n') for line in terms_file]  # sample line: SQL Injection
    with open('search_sentences') as sentences_file:
        string_sentences = [line.rstrip('\n') for line in sentences_file]  # sample line: Search confluence for $entity
    template_sentences = {Template(s_sentence) for s_sentence in string_sentences}
    print_inner_json_list(create_custom_ner_entity("search_confluence", "security_term", template_sentences, 3, dropins))


train_every_dropin_with_three_random_sentences()