@compustar · Created January 26, 2023
Snippets in employee_handbook_qna_gpt3.ipynb
from sentence_transformers import SentenceTransformer, util
import openai

# Load the handbook and split it into passages on blank-line boundaries
with open('HKIHRMEmployeeHandbook.txt') as f:
    full_text = f.read()
passages = [line.strip() for line in full_text.split('\n \n') if len(line.strip()) > 0]
# Bi-encoder tuned for semantic search with cosine similarity
model_name = 'multi-qa-mpnet-base-cos-v1'
bi_encoder = SentenceTransformer(model_name)
corpus_embeddings = bi_encoder.encode(passages, convert_to_tensor=True, show_progress_bar=True)

# Number of passages to retrieve per query; the gist uses top_k below
# without showing its definition, so this value is an assumption
top_k = 5
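# Encoding the whole handbook is the slow step, so the embeddings can be
# cached between runs. A minimal sketch, assuming the tensor fits on disk
# (the file name is my own choice, not from the gist):
import torch

torch.save(corpus_embeddings, 'corpus_embeddings.pt')     # cache after the first run
# corpus_embeddings = torch.load('corpus_embeddings.pt')  # reload on later runs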
query = 'what should I do if I worked overtime?'
question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
# semantic_search returns one list of hits per query; each hit is a dict
# with 'corpus_id' and 'score'
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)[0]
# Sort by corpus_id so the retrieved passages are stitched back together
# in document order
hit_ids = sorted(hit['corpus_id'] for hit in hits)
context = "\n".join(passages[i] for i in hit_ids)
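# To sanity-check retrieval before prompting the model, the hits (already
# ranked by decreasing cosine score) can be printed; a sketch, not part of
# the original snippet:
for hit in hits:
    print(f"{hit['score']:.3f}  {passages[hit['corpus_id']][:80]}")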
template = """Context:
<<context>>
Answer the following question:
Q: <<query>>
A:
"""
prompt = template.replace('<<context>>', context).replace('<<query>>', query)
# openai was already imported above; prompt for the key rather than hard-coding it
openai.api_key = input("OpenAI API Key: ")
response = openai.Completion.create(engine="text-davinci-003", prompt=prompt, max_tokens=256, temperature=0.2)
print(response['choices'][0]['text'])
# Token counting for the prompt. Assumption: the notebook defines `tokenizer`
# in a cell not shown here; the GPT-2 tokenizer is a close proxy for davinci's.
from transformers import GPT2TokenizerFast
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

def answer(query):
    # Encode the query using the bi-encoder and find potentially relevant passages
    question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
    hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)[0]
    # Reassemble the retrieved passages in document order
    hit_ids = sorted(hit['corpus_id'] for hit in hits)
    context = "\n".join(passages[i] for i in hit_ids)
    template = """Context:
<<context>>
Answer the following question by paraphrasing it and then elaborating on the answer:
Q: <<query>>
A:
"""
    prompt = template.replace('<<context>>', context).replace('<<query>>', query)
    # Spend whatever is left of the ~4k-token context window on the completion
    prompt_length = len(tokenizer(prompt)['input_ids'])
    response = openai.Completion.create(engine="text-davinci-003", prompt=prompt,
                                        max_tokens=4096 - prompt_length, temperature=0.2)
    return response['choices'][0]['text']
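# Quick smoke test of the helper (the answer depends on the handbook text
# and the model, so no fixed output is shown here):
print(answer("how many days of annual leave do I have?"))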
while True:
    query = input("Q: ")
    if query == "xxx":
        break
    ans = answer(query)
    print(f"A: {ans}")
    print("=" * 70)
    print()
import gradio as gr
examples = [
    ["can I carry forward annual leave?"],
    ["am I entitled to compensation leave?"],
    ["how many days of annual leave do I have?"],
    ["how do I claim reimbursement of company expenses?"],
    ["do I get paid if I am sick?"],
    ["what should I do if I get sick?"],
]
title = "Q&A Demo"

def inference(text):
    return answer(text)

io = gr.Interface(
    inference,
    gr.Textbox(lines=3),
    outputs=[
        gr.Textbox(lines=3, label="GPT 3.5")
    ],
    title=title,
    examples=examples,
)
io.launch()