Created March 23, 2022, 15:04
-
-
Save jcabot/873bf02c7b0919f2678ec34b376ccec3 to your computer and use it in GitHub Desktop.
Chatbot Data Processing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def train(bot: Bot):
    """Train the NLU engine of *bot*: each of its contexts is trained in turn."""
    for ctx in bot.contexts:
        __train_context(ctx, bot.configuration)
def __train_context(context: NLUContext, configuration: NlpConfiguration):
    """Train the NLU model for a single *context*.

    Gathers the processed training sentences of every intent in the context,
    fits a Keras tokenizer on them, stores the tokenizer, the padded token
    sequences and the per-sentence intent labels on *context*, and finally
    fits the classification model.

    :param context: the NLU context whose intents provide the training data;
        mutated in place (tokenizer, training_sentences, training_sequences,
        training_labels are assigned).
    :param configuration: NLP settings (num_words, lower, oov_token,
        input_max_num_tokens, num_epochs).
    """
    tokenizer = tf.keras.preprocessing.text.Tokenizer(
        num_words=configuration.num_words,
        lower=configuration.lower,
        oov_token=configuration.oov_token,
    )
    total_training_sentences: list[str] = []
    total_labels_training_sentences: list[int] = []
    # enumerate replaces context.intents.index(intent), which was an O(n^2)
    # lookup per iteration and would return the wrong position if two intents
    # ever compared equal.
    for index_intent, intent in enumerate(context.intents):
        preprocess_training_sentences(intent, configuration)
        sentences = intent.processed_training_sentences
        total_training_sentences.extend(sentences)
        # one label per sentence: the index of the intent it belongs to
        total_labels_training_sentences.extend([index_intent] * len(sentences))
    tokenizer.fit_on_texts(total_training_sentences)
    context.tokenizer = tokenizer
    context.training_sentences = total_training_sentences
    context.training_sequences = tf.keras.preprocessing.sequence.pad_sequences(
        tokenizer.texts_to_sequences(total_training_sentences),
        padding='post',
        maxlen=configuration.input_max_num_tokens,
    )
    context.training_labels = total_labels_training_sentences
    # Definition of the NN model as shown before
    # NOTE(review): `model` is not defined anywhere in this snippet — it must
    # be built here (or be in enclosing scope) before fit() is called.
    history = model.fit(
        np.array(context.training_sequences),
        np.array(context.training_labels),
        epochs=configuration.num_epochs,
        verbose=2,
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.