Created
February 7, 2021 13:17
-
-
Save bumie-e/2874fcc5c4baa0b1d77daa2d78c28d09 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random | |
import json | |
import torch | |
from model import NeuralNet | |
from nltk_utils import bag_of_words, tokenize | |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
with open('intents.json', 'r') as json_data: | |
intents = json.load(json_data) | |
FILE = "data.pth" | |
data = torch.load(FILE) | |
input_size = data["input_size"] | |
hidden_size = data["hidden_size"] | |
output_size = data["output_size"] | |
all_words = data['all_words'] | |
tags = data['tags'] | |
model_state = data["model_state"] | |
model = NeuralNet(input_size, hidden_size, output_size).to(device) | |
model.load_state_dict(model_state) | |
model.eval() | |
bot_name = "Sam" | |
print("Let's chat! (type 'quit' to exit)") | |
while True: | |
# sentence = "do you use credit cards?" | |
sentence = input("You: ") | |
if sentence == "quit": | |
break | |
sentence = tokenize(sentence) | |
X = bag_of_words(sentence, all_words) | |
X = X.reshape(1, X.shape[0]) | |
X = torch.from_numpy(X).to(device) | |
output = model(X) | |
_, predicted = torch.max(output, dim=1) | |
tag = tags[predicted.item()] | |
probs = torch.softmax(output, dim=1) | |
prob = probs[0][predicted.item()] | |
if prob.item() > 0.75: | |
for intent in intents['intents']: | |
if tag == intent["tag"]: | |
print(f"{bot_name}: {random.choice(intent['responses'])}") | |
else: | |
print(f"{bot_name}: I do not understand...") |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment