Last active
July 24, 2023 06:39
-
-
Save vamsigutta/e12747aa035c8d0fccb037cfd1fc3eea to your computer and use it in GitHub Desktop.
This code can be used to create a basic QA chatbot.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sentence_transformers import SentenceTransformer, util | |
import torch | |
# Minimum cosine similarity a stored question must strictly exceed to count
# as a match; otherwise the bot replies with FALLBACK_ANSWER.
SIMILARITY_THRESHOLD = 0.75
FALLBACK_ANSWER = "Please refer to the documentation ..."

# (question, answer) pairs. Keeping each pair together guarantees that the
# question and answer lists built from it in main() stay index-aligned.
FAQ = [
    ("hi", "Hi there, How can I help you"),
    ("What is the product", "The product is an QA chatbot"),
    ("what is the cost of the product", "The product is available for free"),
]


def select_answer(cosine_scores, answers, threshold=SIMILARITY_THRESHOLD):
    """Return the answer whose question scored highest, or the fallback.

    Args:
        cosine_scores: Tensor of similarity scores between one query and each
            stored question (e.g. the result of ``util.cos_sim``). It is
            flattened, so a ``(1, n)`` or ``(n,)`` tensor both work.
        answers: Answers index-aligned with the scored questions.
        threshold: The best score must be strictly greater than this.

    Returns:
        The answer at the index of the highest score (first one on ties), or
        ``FALLBACK_ANSWER`` when that score does not exceed ``threshold`` or
        when there are no scores at all.
    """
    scores = cosine_scores.flatten()
    if scores.numel() == 0:
        # Nothing to compare against; argmax would raise on an empty tensor.
        return FALLBACK_ANSWER
    best_idx = int(torch.argmax(scores))
    if scores[best_idx].item() > threshold:
        return answers[best_idx]
    return FALLBACK_ANSWER


def main():
    """Answer queries interactively by matching them to the closest FAQ question."""
    questions = [question for question, _ in FAQ]
    answers = [answer for _, answer in FAQ]

    # Load the pretrained model.
    model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
    model.max_seq_length = 512

    # Embed the stored questions once, up front, so each query only needs
    # a single encode call.
    embeddings = model.encode(
        questions,
        show_progress_bar=True,
        convert_to_tensor=True,
    )

    while True:
        try:
            query = input("Enter query: ")
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C ends the session cleanly instead of a traceback.
            print()
            break
        if not query.strip():
            # Blank input carries no meaning; don't embed or match it.
            continue
        query_embedding = model.encode(query, convert_to_tensor=True)
        cosine_scores = util.cos_sim(query_embedding, embeddings)
        print(select_answer(cosine_scores, answers))


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment