# Gist: pamelafox/a3fdea186b687509c02cb186ca203328
# Author: @pamelafox (created November 7, 2023 20:21)
# Description: Chat approach with additional function call
import json
import logging
import re
from typing import Any, AsyncGenerator, Optional, Union
import aiohttp
import openai
from azure.search.documents.aio import SearchClient
from azure.search.documents.models import QueryType
from approaches.approach import Approach
from core.messagebuilder import MessageBuilder
from core.modelhelper import get_token_limit
from text import nonewlines
from github_issues import query_df
class ChatReadRetrieveReadApproach(Approach):
    """
    Chat approach that answers either from an Azure Cognitive Search index or from GitHub issues.

    Per request:
      1. The chat model is called with two functions (``search_sources`` and ``github_issues``)
         and picks one, producing its arguments (or plain text used as a search query).
      2. For a search query, documents are retrieved with text, vector or hybrid search.
         For ``github_issues``, issue data is fetched with ``query_df``.
      3. The retrieved content plus recent chat history is sent to the chat model for the
         final answer, optionally streamed.
    """

    # Chat roles
    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"

    # Value the query-generation prompt asks the model to return when it cannot build a query.
    NO_RESPONSE = "0"

    system_message_chat_conversation = """Assistant helps the software engineers to understand and troubleshoot issues with various systems Be brief in your answers. If someone asks questions on how can you help me, share the functions that you can perform.
Answer ONLY with the facts listed in the list of sources below. If there isn't enough information below, say you don't know. Do not generate answers that don't use the sources below. If asking a clarifying question to the user would help, ask the question.
For tabular information return it as an html table. Do not return markdown format. If the question is not in English, answer in the language used in the question.
Each source has a name followed by colon and the actual information, always include the source name for each fact you use in the response. Use square brackets to reference the source, for example [info1.txt]. Don't combine sources, list each source separately, for example [info1.txt][info2.pdf].
{follow_up_questions_prompt}
{injected_prompt}
"""

    follow_up_questions_prompt_content = """Generate three very brief follow-up questions that the user would likely ask next is about clarifying the context.
Use double angle brackets to reference the questions, e.g. <<Are there exclusions for this access type?>>.
Try not to repeat questions that have already been asked.
Only generate questions and do not generate any text before or after the questions, such as 'Next Questions'"""

    query_prompt_template = """Below is a history of the conversation so far, and a new question asked by the user that needs to be answered by searching in a knowledge base about employee healthcare plans and the employee handbook.
You have access to Azure Cognitive Search index with 100's of documents.
Generate a search query based on the conversation and the new question.
Do not include cited source filenames and document names e.g info.txt or doc.pdf in the search query terms.
Do not include any text inside [] or <<>> in the search query terms.
Do not include any special characters like '+'.
If the question is not in English, translate the question to English before generating the search query.
If you cannot generate a search query, return just the number 0.
"""

    query_prompt_few_shots = [
        {"role": USER, "content": "How to debug aml compute auth issues?"},
        {"role": ASSISTANT, "content": "Show steps to debug"},
        {"role": USER, "content": "Does training job support custom identity?"},
        {"role": ASSISTANT, "content": "Say 'Yes' or 'No', then show the details"},
    ]

    # Instructions prepended to GitHub issue data, keyed by the github_issues argument that selected them.
    git_hub_user_prompts = {
        "top_issues": "Please format this into html table with indexes starting from number 1 and show the date in Pacific date time format, donot use source or citation",
        "specific_issue": "Donot use html table format for this, start with the 'Title:' of the Issue, Then 'Open date:', followed by the technical 'Summary:' of the issue, format it nicely with each section having header",
    }

    def __init__(
        self,
        search_client: SearchClient,
        openai_host: str,
        chatgpt_deployment: Optional[str],  # Not needed for non-Azure OpenAI
        chatgpt_model: str,
        embedding_deployment: Optional[str],  # Not needed for non-Azure OpenAI or for retrieval_mode="text"
        embedding_model: str,
        sourcepage_field: str,
        content_field: str,
        query_language: str,
        query_speller: str,
    ):
        self.search_client = search_client
        self.openai_host = openai_host
        self.chatgpt_deployment = chatgpt_deployment
        self.chatgpt_model = chatgpt_model
        self.embedding_deployment = embedding_deployment
        self.embedding_model = embedding_model
        self.sourcepage_field = sourcepage_field
        self.content_field = content_field
        self.query_language = query_language
        self.query_speller = query_speller
        self.chatgpt_token_limit = get_token_limit(chatgpt_model)
        # Kept so existing readers of this attribute keep working. It is no longer mutated per
        # request, and run() dispatches on its `stream` argument alone: shared per-request state
        # here used to leak into later requests on the same instance.
        self.is_streaming = True

    def _build_functions(self) -> list[dict[str, Any]]:
        """Return the function-calling schema offered to the model when generating the query."""
        return [
            {
                "name": "search_sources",
                "description": "Retrieve sources from the Azure Cognitive Search index",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "search_query": {
                            "type": "string",
                            "description": "Query string to retrieve documents from azure search eg: 'How to debug compute issues'",
                        }
                    },
                    "required": ["search_query"],
                },
            },
            {
                "name": "github_issues",
                "description": "Retrieve Azure SDK related issues reported by users",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "top_issues": {
                            "type": "string",
                            "description": "Set 'true' or 'false' based on if the question is on show the top issues eg: set it to 'true' if the question is 'What are the top issues'",
                        },
                        "specific_issue": {
                            "type": "string",
                            "description": "If the ask is on show details of an issue with a number , get the issue link, that needs tobe shown in details of the issue eg: 'Show me the details of the issue number 5', then get the issue link of 5th issue and show the details, default value id None",
                        },
                    },
                    # Only top_issues is mandatory. specific_issue is optional by its own
                    # description; the previous list named a property ("specific_issue_by_index")
                    # that does not exist in the schema.
                    "required": ["top_issues"],
                },
            },
        ]

    async def _search_documents(
        self,
        query_text: Optional[str],
        *,
        overrides: dict[str, Any],
        search_filter: Any,
        top: Any,
        has_text: bool,
        has_vector: bool,
        use_semantic_captions: bool,
    ) -> tuple[Optional[str], list[str]]:
        """
        Retrieve the top documents for ``query_text`` from the search index.

        :return: ``(effective_query_text, results)``. ``effective_query_text`` is ``None`` when the
            retrieval mode does not use text. Each result is a ``"<sourcepage>: <content>"`` string.
        """
        # If retrieval mode includes vectors, compute an embedding for the query.
        if has_vector:
            embedding_args = {"deployment_id": self.embedding_deployment} if self.openai_host == "azure" else {}
            embedding = await openai.Embedding.acreate(**embedding_args, model=self.embedding_model, input=query_text)
            query_vector = embedding["data"][0]["embedding"]
        else:
            query_vector = None

        # Only keep the text query if the retrieval mode uses text, otherwise drop it.
        if not has_text:
            query_text = None

        # Use semantic L2 reranker if requested and if retrieval mode is text or hybrid (vectors + text).
        if overrides.get("semantic_ranker") and has_text:
            r = await self.search_client.search(
                query_text,
                filter=search_filter,
                query_type=QueryType.SEMANTIC,
                query_language=self.query_language,
                query_speller=self.query_speller,
                semantic_configuration_name="azureml-default",
                top=top,
                query_caption="extractive|highlight-false" if use_semantic_captions else None,
                vector=query_vector,
                top_k=50 if query_vector else None,
                vector_fields="content_vector_open_ai" if query_vector else None,
            )
        else:
            r = await self.search_client.search(
                query_text,
                filter=search_filter,
                top=top,
                vector=query_vector,
                top_k=50 if query_vector else None,
                vector_fields="content_vector_open_ai" if query_vector else None,
            )

        if use_semantic_captions:
            results = [
                doc[self.sourcepage_field] + ": " + nonewlines(" . ".join([c.text for c in doc["@search.captions"]]))
                async for doc in r
            ]
        else:
            results = [doc[self.sourcepage_field] + ": " + nonewlines(doc[self.content_field]) async for doc in r]
        return query_text, results

    def _build_system_message(self, prompt_override: Optional[str], follow_up_questions_prompt: str) -> str:
        """
        Build the system prompt for the final answer.

        With no override the default template is used. An override starting with ``>>>`` is
        injected into the default template. Any other override replaces the template entirely.
        """
        if prompt_override is None:
            return self.system_message_chat_conversation.format(
                injected_prompt="", follow_up_questions_prompt=follow_up_questions_prompt
            )
        if prompt_override.startswith(">>>"):
            return self.system_message_chat_conversation.format(
                injected_prompt=prompt_override[3:] + "\n", follow_up_questions_prompt=follow_up_questions_prompt
            )
        return prompt_override.format(follow_up_questions_prompt=follow_up_questions_prompt)

    def _github_user_instruction(self, function_args: dict[str, Any]) -> str:
        """Pick the formatting instruction for GitHub issue data from the github_issues arguments."""
        # The schema asks for the string 'true'; accept a JSON boolean too, and tolerate
        # missing keys (the model does not always send every argument).
        if str(function_args.get("top_issues")).lower() == "true":
            return self.git_hub_user_prompts["top_issues"]
        if function_args.get("specific_issue"):
            return self.git_hub_user_prompts["specific_issue"]
        return ""

    async def run_until_final_call(
        self,
        history: list[dict[str, str]],
        overrides: dict[str, Any],
        auth_claims: dict[str, Any],
        should_stream: bool = False,
    ) -> tuple:
        """
        Generate the query, retrieve content, and prepare (but do not await) the final completion.

        :param history: chat messages; ``history[-1]["content"]`` is the current user question.
        :param overrides: per-request options read here: retrieval_mode, top, semantic_ranker,
            semantic_captions, suggest_followup_questions, prompt_template, temperature.
        :param auth_claims: forwarded to ``build_filter`` to build the search filter.
        :param should_stream: passed as ``stream`` to the final ChatCompletion call.
        :return: ``(extra_info, chat_coroutine)``. ``extra_info`` has ``data_points`` and
            ``thoughts``; ``chat_coroutine`` is the un-awaited final ``ChatCompletion.acreate``.
        """
        has_text = overrides.get("retrieval_mode") in ["text", "hybrid", None]
        has_vector = overrides.get("retrieval_mode") in ["vectors", "hybrid", None]
        use_semantic_captions = bool(overrides.get("semantic_captions") and has_text)
        top = overrides.get("top", 3)
        search_filter = self.build_filter(overrides, auth_claims)

        original_user_query = history[-1]["content"]
        user_query_request = "Generate search query for: " + original_user_query

        # STEP 1: Generate an optimized keyword search query (or a github_issues call) based on
        # the chat history and the last question.
        messages = self.get_messages_from_history(
            system_prompt=self.query_prompt_template,
            model_id=self.chatgpt_model,
            history=history,
            user_content=user_query_request,
            max_tokens=self.chatgpt_token_limit - len(user_query_request),
            few_shots=self.query_prompt_few_shots,
        )
        chatgpt_args = {"deployment_id": self.chatgpt_deployment} if self.openai_host == "azure" else {}
        chat_completion = await openai.ChatCompletion.acreate(
            **chatgpt_args,
            model=self.chatgpt_model,
            messages=messages,
            temperature=0.0,
            max_tokens=100,  # Setting too low risks malformed JSON, setting too high may affect performance
            n=1,
            functions=self._build_functions(),
            function_call="auto",
        )
        # Either a search query string or the parsed github_issues arguments (a dict).
        query_text = self.get_search_query(chat_completion, original_user_query)
        is_github_request = isinstance(query_text, dict)

        # STEP 2: Retrieve content for the answer.
        if is_github_request:
            content = query_df(query_text.get("top_issues"), query_text.get("specific_issue", None))
            results = []
        else:
            query_text, results = await self._search_documents(
                query_text,
                overrides=overrides,
                search_filter=search_filter,
                top=top,
                has_text=has_text,
                has_vector=has_vector,
                use_semantic_captions=use_semantic_captions,
            )
            content = "\n".join(results)

        # STEP 3: Generate a contextual and content specific answer using the retrieved content and
        # chat history. The client may replace the whole prompt, or inject into it using >>>.
        follow_up_questions_prompt = (
            self.follow_up_questions_prompt_content if overrides.get("suggest_followup_questions") else ""
        )
        system_message = self._build_system_message(overrides.get("prompt_template"), follow_up_questions_prompt)

        response_token_limit = 1024
        messages_token_limit = self.chatgpt_token_limit - response_token_limit

        if is_github_request:
            instruction = self._github_user_instruction(query_text)
            # Separate the instruction from the issue data so the two do not run together.
            user_content = f"{instruction}\n\n{content}" if instruction else content
        else:
            # Model does not handle lengthy system messages well. Moving sources to latest user
            # conversation to solve follow up questions prompt.
            user_content = original_user_query + "\n\nSources:\n" + content

        messages = self.get_messages_from_history(
            system_prompt=system_message,
            model_id=self.chatgpt_model,
            history=history[-6:],  # last six convo due to token limit
            user_content=user_content,
            max_tokens=messages_token_limit,
        )
        msg_to_display = "\n\n".join([str(message) for message in messages])

        extra_info = {
            "data_points": results,
            "thoughts": f"Searched for:<br>{query_text}<br><br>Conversations:<br>"
            + msg_to_display.replace("\n", "<br>"),
        }
        chat_coroutine = openai.ChatCompletion.acreate(
            **chatgpt_args,
            model=self.chatgpt_model,
            messages=messages,
            temperature=overrides.get("temperature") or 0.7,
            max_tokens=response_token_limit,
            n=1,
            stream=should_stream,
        )
        return (extra_info, chat_coroutine)

    async def run_without_streaming(
        self,
        history: list[dict[str, str]],
        overrides: dict[str, Any],
        auth_claims: dict[str, Any],
        session_state: Any = None,
    ) -> dict[str, Any]:
        """Run the full pipeline and return a single chat completion dict with context attached."""
        extra_info, chat_coroutine = await self.run_until_final_call(
            history, overrides, auth_claims, should_stream=False
        )
        chat_resp = dict(await chat_coroutine)
        chat_resp["choices"][0]["context"] = extra_info
        if overrides.get("suggest_followup_questions"):
            content, followup_questions = self.extract_followup_questions(chat_resp["choices"][0]["message"]["content"])
            chat_resp["choices"][0]["message"]["content"] = content
            chat_resp["choices"][0]["context"]["followup_questions"] = followup_questions
        chat_resp["choices"][0]["session_state"] = session_state
        return chat_resp

    async def run_with_streaming(
        self,
        history: list[dict[str, str]],
        overrides: dict[str, Any],
        auth_claims: dict[str, Any],
        session_state: Any = None,
    ) -> AsyncGenerator[dict, None]:
        """
        Run the pipeline and yield completion chunks.

        The first chunk carries the context (data points and thoughts). When follow-up questions
        are requested, text from the first ``<<`` onward is held back and emitted as a final
        chunk with ``context.followup_questions``.
        """
        extra_info, chat_coroutine = await self.run_until_final_call(
            history, overrides, auth_claims, should_stream=True
        )
        yield {
            "choices": [
                {
                    "delta": {"role": self.ASSISTANT},
                    "context": extra_info,
                    "session_state": session_state,
                    "finish_reason": None,
                    "index": 0,
                }
            ],
            "object": "chat.completion.chunk",
        }

        followup_questions_started = False
        followup_content = ""
        async for event in await chat_coroutine:
            # "2023-07-01-preview" API version has a bug where first response has empty choices
            if event["choices"]:
                # A delta's "content" can be present but null, so normalize it to a string.
                content = event["choices"][0]["delta"].get("content") or ""
                # If the event contains "<<", it starts the follow-up questions: truncate there.
                if overrides.get("suggest_followup_questions") and "<<" in content:
                    followup_questions_started = True
                    earlier_content = content[: content.index("<<")]
                    if earlier_content:
                        event["choices"][0]["delta"]["content"] = earlier_content
                        yield event
                    followup_content += content[content.index("<<") :]
                elif followup_questions_started:
                    followup_content += content
                else:
                    yield event
        if followup_content:
            _, followup_questions = self.extract_followup_questions(followup_content)
            yield {
                "choices": [
                    {
                        "delta": {"role": self.ASSISTANT},
                        "context": {"followup_questions": followup_questions},
                        "finish_reason": None,
                        "index": 0,
                    }
                ],
                "object": "chat.completion.chunk",
            }

    async def run(
        self,
        messages: list[dict],
        stream: bool = False,
        session_state: Any = None,
        context: Optional[dict[str, Any]] = None,
    ) -> Union[dict[str, Any], AsyncGenerator[dict[str, Any], None]]:
        """
        Entry point: return a completion dict when ``stream`` is falsy, else an async chunk generator.

        :param context: optional dict with ``overrides`` and ``auth_claims`` entries.
        """
        # Avoid a shared mutable default argument.
        context = context or {}
        overrides = context.get("overrides", {})
        auth_claims = context.get("auth_claims", {})
        # Dispatch on the caller's request only. Per-instance flags must not influence it, because
        # this object can be shared across requests.
        if not stream:
            # Workaround for: https://github.com/openai/openai-python/issues/371
            async with aiohttp.ClientSession() as s:
                openai.aiosession.set(s)
                response = await self.run_without_streaming(messages, overrides, auth_claims, session_state)
            return response
        return self.run_with_streaming(messages, overrides, auth_claims, session_state)

    def get_messages_from_history(
        self,
        system_prompt: str,
        model_id: str,
        history: list[dict[str, str]],
        user_content: str,
        max_tokens: int,
        few_shots: Optional[list[dict[str, str]]] = None,
    ) -> list:
        """
        Build the message list: system prompt, few-shot examples, prior history, then ``user_content``.

        Prior history is added newest-first until the ``max_tokens`` budget would be exceeded, so
        the oldest turns are the ones dropped.
        """
        if few_shots is None:
            few_shots = []
        message_builder = MessageBuilder(system_prompt, model_id)

        # Add examples to show the chat what responses we want. It will try to mimic any responses
        # and make sure they match the rules laid out in the system message.
        for shot in reversed(few_shots):
            message_builder.insert_message(shot.get("role"), shot.get("content"))

        append_index = len(few_shots) + 1
        message_builder.insert_message(self.USER, user_content, index=append_index)
        total_token_count = message_builder.count_tokens_for_message(message_builder.messages[-1])

        # history[-1] is skipped: callers build user_content from the current (last) turn.
        newest_to_oldest = list(reversed(history[:-1]))
        for message in newest_to_oldest:
            potential_message_count = message_builder.count_tokens_for_message(message)
            if (total_token_count + potential_message_count) > max_tokens:
                logging.debug("Reached max tokens of %d, history will be truncated", max_tokens)
                break
            message_builder.insert_message(message["role"], message["content"], index=append_index)
            total_token_count += potential_message_count
        return message_builder.messages

    @staticmethod
    def _parse_function_arguments(function_call: dict[str, Any]) -> Optional[dict[str, Any]]:
        """Decode a function call's JSON ``arguments``; return None if they are malformed or not an object."""
        try:
            args = json.loads(function_call["arguments"])
        except json.JSONDecodeError:
            # Query generation caps max_tokens, so the arguments can arrive truncated.
            logging.warning("Could not decode arguments for function call %r", function_call.get("name"))
            return None
        return args if isinstance(args, dict) else None

    def get_search_query(self, chat_completion: dict[str, Any], user_query: str):
        """
        Extract the search query (or github_issues arguments) from the query-generation response.

        :return: the parsed argument dict for a ``github_issues`` call, otherwise a query string.
            Falls back to ``user_query`` when the model gave no usable query.
        """
        response_message = chat_completion["choices"][0]["message"]
        if function_call := response_message.get("function_call"):
            if function_call["name"] == "search_sources":
                args = self._parse_function_arguments(function_call)
                search_query = args.get("search_query", self.NO_RESPONSE) if args else self.NO_RESPONSE
                if search_query != self.NO_RESPONSE:
                    return search_query
            elif function_call["name"] == "github_issues":
                args = self._parse_function_arguments(function_call)
                if args is not None:
                    return args
        elif query_text := response_message.get("content"):
            if query_text.strip() != self.NO_RESPONSE:
                return query_text
        return user_query

    def extract_followup_questions(self, content: str):
        """Split ``content`` into the answer text before ``<<`` and the list of ``<<question>>`` texts."""
        return content.split("<<")[0], re.findall(r"<<([^>>]+)>>", content)
# (Scraped GitHub page footer) Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment