@vamsigutta
Last active July 24, 2023 06:39
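This code can be used to create a basic QA chatbot. It embeds a small set of known questions, then answers each user query with the answer of the most similar question, falling back to a default message when nothing is similar enough.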
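```python
from sentence_transformers import SentenceTransformer, util
import torch

# Minimum cosine similarity required before a stored answer is returned
SIMILARITY_THRESHOLD = 0.75

if __name__ == "__main__":
    # Known questions and their answers, aligned by index
    questions = ["hi", "What is the product", "what is the cost of the product"]
    answers = ["Hi there, How can I help you", "The product is a QA chatbot", "The product is available for free"]

    ## Load the pretrained model
    model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
    model.max_seq_length = 512  # inputs longer than this are truncated

    ## Generate the embeddings for the known questions once, up front
    embeddings = model.encode(
        questions,
        show_progress_bar=True,
        convert_to_tensor=True,
    )

    while True:
        try:
            query = input("Enter query: ")
        except (EOFError, KeyboardInterrupt):
            break  # exit cleanly on Ctrl-D / Ctrl-C
        query_embeddings = model.encode(query, convert_to_tensor=True)
        # Shape (1, len(questions)): similarity of the query to each stored question
        cosine_scores = util.cos_sim(query_embeddings, embeddings)
        similarity_score = torch.max(cosine_scores)
        if similarity_score > SIMILARITY_THRESHOLD:
            # .item() turns the 0-d index tensor into a plain Python int
            answer_idx = torch.argmax(cosine_scores).item()
            result = answers[answer_idx]
        else:
            result = "Please refer to the documentation ..."
        print(result)
```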
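The matching step in the script boils down to: compute the cosine similarity between the query embedding and every stored question embedding, take the best match, and answer only if its score clears the 0.75 threshold. Below is a minimal, dependency-free sketch of that logic, assuming embeddings are plain lists of floats. `cosine_similarity` and `pick_answer` are hypothetical helper names, not part of `sentence_transformers`, and the 2-D vectors are toy stand-ins for real model output.

```python
import math
from typing import List, Sequence

FALLBACK = "Please refer to the documentation ..."


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two vectors (the quantity util.cos_sim computes pairwise)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0  # treat zero vectors as dissimilar instead of dividing by zero
    return dot / (norm_a * norm_b)


def pick_answer(
    query_vec: Sequence[float],
    question_vecs: List[Sequence[float]],
    answers: List[str],
    threshold: float = 0.75,
) -> str:
    """Return the answer of the most similar stored question, or a fallback message."""
    if not question_vecs:
        return FALLBACK
    scores = [cosine_similarity(query_vec, q) for q in question_vecs]
    # Index of the highest score (the first one wins on ties)
    best_idx = max(range(len(scores)), key=scores.__getitem__)
    if scores[best_idx] > threshold:
        return answers[best_idx]
    return FALLBACK


if __name__ == "__main__":
    # Hypothetical 2-D "embeddings" standing in for real model output
    question_vecs = [[1.0, 0.0], [0.0, 1.0]]
    answers = ["Hi there, How can I help you", "The product is a QA chatbot"]
    print(pick_answer([0.9, 0.1], question_vecs, answers))  # close to the first question
    print(pick_answer([1.0, 1.0], question_vecs, answers))  # cos is about 0.707 for both, below 0.75
```

`util.cos_sim` computes the same quantity for every query/question pair at once on tensors; the explicit loop here only makes the arithmetic and the strict `>` threshold check visible.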