Created
November 7, 2023 20:21
-
-
Save pamelafox/a3fdea186b687509c02cb186ca203328 to your computer and use it in GitHub Desktop.
Chat approach with additional function call
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
import logging | |
import re | |
from typing import Any, AsyncGenerator, Optional, Union | |
import aiohttp | |
import openai | |
from azure.search.documents.aio import SearchClient | |
from azure.search.documents.models import QueryType | |
from approaches.approach import Approach | |
from core.messagebuilder import MessageBuilder | |
from core.modelhelper import get_token_limit | |
from text import nonewlines | |
from github_issues import query_df | |
class ChatReadRetrieveReadApproach(Approach):
    """Chat approach that lets the model choose between two data sources.

    A first chat completion is made with two function definitions attached:
    ``search_sources`` (query the Azure Cognitive Search index) and
    ``github_issues`` (look up issue data through ``query_df``). Depending on
    which function the model calls, the matching content is retrieved, placed
    in the latest user message, and a second chat completion produces the
    answer shown to the user.
    """

    # Chat roles
    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"

    # Value the query prompt tells the model to return when it cannot produce a search query.
    NO_RESPONSE = "0"

    system_message_chat_conversation = """Assistant helps the software engineers to understand and troubleshoot issues with various systems Be brief in your answers. If someone asks questions on how can you help me, share the functions that you can perform.
Answer ONLY with the facts listed in the list of sources below. If there isn't enough information below, say you don't know. Do not generate answers that don't use the sources below. If asking a clarifying question to the user would help, ask the question.
For tabular information return it as an html table. Do not return markdown format. If the question is not in English, answer in the language used in the question.
Each source has a name followed by colon and the actual information, always include the source name for each fact you use in the response. Use square brackets to reference the source, for example [info1.txt]. Don't combine sources, list each source separately, for example [info1.txt][info2.pdf].
{follow_up_questions_prompt}
{injected_prompt}
"""

    follow_up_questions_prompt_content = """Generate three very brief follow-up questions that the user would likely ask next is about clarifying the context.
Use double angle brackets to reference the questions, e.g. <<Are there exclusions for this access type?>>.
Try not to repeat questions that have already been asked.
Only generate questions and do not generate any text before or after the questions, such as 'Next Questions'"""

    query_prompt_template = """Below is a history of the conversation so far, and a new question asked by the user that needs to be answered by searching in a knowledge base about employee healthcare plans and the employee handbook.
You have access to Azure Cognitive Search index with 100's of documents.
Generate a search query based on the conversation and the new question.
Do not include cited source filenames and document names e.g info.txt or doc.pdf in the search query terms.
Do not include any text inside [] or <<>> in the search query terms.
Do not include any special characters like '+'.
If the question is not in English, translate the question to English before generating the search query.
If you cannot generate a search query, return just the number 0.
"""

    query_prompt_few_shots = [
        {"role": USER, "content": "How to debug aml compute auth issues?"},
        {"role": ASSISTANT, "content": "Show steps to debug"},
        {"role": USER, "content": "Does training job support custom identity?"},
        {"role": ASSISTANT, "content": "Say 'Yes' or 'No', then show the details"},
    ]

    # Formatting instructions prepended to the GitHub issue data, keyed by the
    # github_issues function argument that triggered the lookup.
    git_hub_user_prompts = {
        "top_issues": "Please format this into html table with indexes starting from number 1 and show the date in Pacific date time format, donot use source or citation",
        "specific_issue": "Donot use html table format for this, start with the 'Title:' of the Issue, Then 'Open date:', followed by the technical 'Summary:' of the issue, format it nicely with each section having header",
    }

    def __init__(
        self,
        search_client: SearchClient,
        openai_host: str,
        chatgpt_deployment: Optional[str],  # Not needed for non-Azure OpenAI
        chatgpt_model: str,
        embedding_deployment: Optional[str],  # Not needed for non-Azure OpenAI or for retrieval_mode="text"
        embedding_model: str,
        sourcepage_field: str,
        content_field: str,
        query_language: str,
        query_speller: str,
    ):
        """Store the clients and settings used by every request.

        Args:
            search_client: Async Azure Cognitive Search client used for retrieval.
            openai_host: ``"azure"`` makes the OpenAI calls pass a ``deployment_id``.
            chatgpt_deployment: Deployment id for chat completions (Azure only).
            chatgpt_model: Chat model name; also used to look up the token limit.
            embedding_deployment: Deployment id for embeddings (Azure only).
            embedding_model: Embedding model name.
            sourcepage_field: Document field used as the source name of a result.
            content_field: Document field holding the text of a result.
            query_language: Passed to the search client for semantic queries.
            query_speller: Passed to the search client for semantic queries.
        """
        self.search_client = search_client
        self.openai_host = openai_host
        self.chatgpt_deployment = chatgpt_deployment
        self.chatgpt_model = chatgpt_model
        self.embedding_deployment = embedding_deployment
        self.embedding_model = embedding_model
        self.sourcepage_field = sourcepage_field
        self.content_field = content_field
        self.query_language = query_language
        self.query_speller = query_speller
        self.chatgpt_token_limit = get_token_limit(chatgpt_model)
        # Kept so existing readers of the attribute keep working. It is no longer
        # mutated per request: the instance is shared state, and flipping it
        # during one request changed how every later request was dispatched.
        self.is_streaming = True

    async def run_until_final_call(
        self,
        history: list[dict[str, str]],
        overrides: dict[str, Any],
        auth_claims: dict[str, Any],
        should_stream: bool = False,
    ) -> tuple:
        """Run every step up to, but not including, awaiting the final completion.

        Args:
            history: Conversation messages; the last entry is the new user question.
            overrides: Per-request options. Keys read here: ``retrieval_mode``,
                ``semantic_captions``, ``semantic_ranker``, ``top``,
                ``suggest_followup_questions``, ``prompt_template`` and ``temperature``.
            auth_claims: Forwarded to ``build_filter`` together with ``overrides``.
            should_stream: Passed as ``stream`` to the final chat completion.

        Returns:
            ``(extra_info, chat_coroutine)`` where ``extra_info`` holds the
            ``data_points`` and ``thoughts`` for the UI and ``chat_coroutine`` is
            the not-yet-awaited final ``ChatCompletion.acreate`` call.
        """
        retrieval_mode = overrides.get("retrieval_mode")
        has_text = retrieval_mode in ["text", "hybrid", None]
        has_vector = retrieval_mode in ["vectors", "hybrid", None]
        use_semantic_captions = bool(overrides.get("semantic_captions") and has_text)
        top = overrides.get("top", 3)
        search_filter = self.build_filter(overrides, auth_claims)

        original_user_query = history[-1]["content"]
        user_query_request = "Generate search query for: " + original_user_query

        functions = [
            {
                "name": "search_sources",
                "description": "Retrieve sources from the Azure Cognitive Search index",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "search_query": {
                            "type": "string",
                            "description": "Query string to retrieve documents from azure search eg: 'How to debug compute issues'",
                        }
                    },
                    "required": ["search_query"],
                },
            },
            {
                "name": "github_issues",
                "description": "Retrieve Azure SDK related issues reported by users",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "top_issues": {
                            "type": "string",
                            "description": "Set 'true' or 'false' based on if the question is on show the top issues eg: set it to 'true' if the question is 'What are the top issues'",
                        },
                        "specific_issue": {
                            "type": "string",
                            "description": "If the ask is on show details of an issue with a number , get the issue link, that needs tobe shown in details of the issue eg: 'Show me the details of the issue number 5', then get the issue link of 5th issue and show the details, default value id None",
                        },
                    },
                    # Only names declared under "properties" may be required;
                    # "specific_issue" is optional (its description gives it a default).
                    "required": ["top_issues"],
                },
            },
        ]

        # STEP 1: Generate an optimized keyword search query based on the chat history and the last question
        messages = self.get_messages_from_history(
            system_prompt=self.query_prompt_template,
            model_id=self.chatgpt_model,
            history=history,
            user_content=user_query_request,
            # NOTE(review): this subtracts a character count from a token limit —
            # confirm the approximation is intended.
            max_tokens=self.chatgpt_token_limit - len(user_query_request),
            few_shots=self.query_prompt_few_shots,
        )

        chatgpt_args = {"deployment_id": self.chatgpt_deployment} if self.openai_host == "azure" else {}
        chat_completion = await openai.ChatCompletion.acreate(
            **chatgpt_args,
            model=self.chatgpt_model,
            messages=messages,
            temperature=0.0,
            max_tokens=100,  # Setting too low risks malformed JSON, setting too high may affect performance
            n=1,
            functions=functions,
            function_call="auto",
        )

        # A dict means the model called github_issues; anything else is a search query string.
        query_text = self.get_search_query(chat_completion, original_user_query)
        is_github_request = isinstance(query_text, dict)

        if is_github_request:
            # NOTE(review): content is concatenated to a str below, so query_df is
            # assumed to return a string — confirm against github_issues.query_df.
            content = query_df(query_text.get("top_issues"), query_text.get("specific_issue"))
            results = []
        else:
            # STEP 2: Retrieve relevant documents from the search index with the GPT optimized query

            # If retrieval mode includes vectors, compute an embedding for the query
            if has_vector:
                embedding_args = {"deployment_id": self.embedding_deployment} if self.openai_host == "azure" else {}
                embedding = await openai.Embedding.acreate(
                    **embedding_args, model=self.embedding_model, input=query_text
                )
                query_vector = embedding["data"][0]["embedding"]
            else:
                query_vector = None

            # Only keep the text query if the retrieval mode uses text, otherwise drop it
            if not has_text:
                query_text = None

            search_kwargs: dict[str, Any] = {
                "filter": search_filter,
                "top": top,
                "vector": query_vector,
                "top_k": 50 if query_vector else None,
                "vector_fields": "content_vector_open_ai" if query_vector else None,
            }
            # Use semantic L2 reranker if requested and if retrieval mode is text or hybrid (vectors + text)
            if overrides.get("semantic_ranker") and has_text:
                search_kwargs.update(
                    query_type=QueryType.SEMANTIC,
                    query_language=self.query_language,
                    query_speller=self.query_speller,
                    semantic_configuration_name="azureml-default",
                    query_caption="extractive|highlight-false" if use_semantic_captions else None,
                )
            r = await self.search_client.search(query_text, **search_kwargs)

            if use_semantic_captions:
                results = [
                    doc[self.sourcepage_field] + ": " + nonewlines(" . ".join([c.text for c in doc["@search.captions"]]))
                    async for doc in r
                ]
            else:
                results = [doc[self.sourcepage_field] + ": " + nonewlines(doc[self.content_field]) async for doc in r]
            content = "\n".join(results)

        follow_up_questions_prompt = (
            self.follow_up_questions_prompt_content if overrides.get("suggest_followup_questions") else ""
        )

        # STEP 3: Generate a contextual and content specific answer using the search results and chat history

        # Allow client to replace the entire prompt, or to inject into the exiting prompt using >>>
        prompt_override = overrides.get("prompt_template")
        if prompt_override is None:
            system_message = self.system_message_chat_conversation.format(
                injected_prompt="", follow_up_questions_prompt=follow_up_questions_prompt
            )
        elif prompt_override.startswith(">>>"):
            system_message = self.system_message_chat_conversation.format(
                injected_prompt=prompt_override[3:] + "\n", follow_up_questions_prompt=follow_up_questions_prompt
            )
        else:
            system_message = prompt_override.format(follow_up_questions_prompt=follow_up_questions_prompt)

        response_token_limit = 1024
        messages_token_limit = self.chatgpt_token_limit - response_token_limit

        if is_github_request:
            # The model may omit an argument or send a real boolean instead of the
            # string 'true', so read defensively rather than indexing the dict.
            formatting_prompt = ""
            if str(query_text.get("top_issues")).lower() == "true":
                formatting_prompt = self.git_hub_user_prompts["top_issues"]
            elif query_text.get("specific_issue"):
                formatting_prompt = self.git_hub_user_prompts["specific_issue"]
            user_content = formatting_prompt + content
        else:
            # Model does not handle lengthy system messages well. Moving sources to latest user conversation to solve follow up questions prompt.
            user_content = original_user_query + "\n\nSources:\n" + content

        messages = self.get_messages_from_history(
            system_prompt=system_message,
            model_id=self.chatgpt_model,
            history=history[-6:],  # last six convo due to token limit
            user_content=user_content,
            max_tokens=messages_token_limit,
        )
        msg_to_display = "\n\n".join([str(message) for message in messages])

        extra_info = {
            "data_points": results,
            "thoughts": f"Searched for:<br>{query_text}<br><br>Conversations:<br>"
            + msg_to_display.replace("\n", "<br>"),
        }

        # An explicit temperature of 0 is a valid override; only fall back when it is absent.
        temperature = overrides.get("temperature")
        if temperature is None:
            temperature = 0.7

        chat_coroutine = openai.ChatCompletion.acreate(
            **chatgpt_args,
            model=self.chatgpt_model,
            messages=messages,
            temperature=temperature,
            max_tokens=response_token_limit,
            n=1,
            stream=should_stream,
        )
        return (extra_info, chat_coroutine)

    async def run_without_streaming(
        self,
        history: list[dict[str, str]],
        overrides: dict[str, Any],
        auth_claims: dict[str, Any],
        session_state: Any = None,
    ) -> dict[str, Any]:
        """Return the complete chat response as one dict.

        The first choice is extended with ``context`` (the extra info from
        ``run_until_final_call``) and ``session_state``. When the
        ``suggest_followup_questions`` override is set, follow-up questions are
        moved out of the message content into ``context["followup_questions"]``.
        """
        extra_info, chat_coroutine = await self.run_until_final_call(
            history, overrides, auth_claims, should_stream=False
        )
        chat_resp = dict(await chat_coroutine)
        first_choice = chat_resp["choices"][0]
        first_choice["context"] = extra_info
        if overrides.get("suggest_followup_questions"):
            content, followup_questions = self.extract_followup_questions(first_choice["message"]["content"])
            first_choice["message"]["content"] = content
            first_choice["context"]["followup_questions"] = followup_questions
        first_choice["session_state"] = session_state
        return chat_resp

    async def run_with_streaming(
        self,
        history: list[dict[str, str]],
        overrides: dict[str, Any],
        auth_claims: dict[str, Any],
        session_state: Any = None,
    ) -> AsyncGenerator[dict, None]:
        """Yield the chat response as a sequence of completion chunks.

        The first chunk carries the context and session state. Answer chunks are
        forwarded as received until a ``<<`` marker appears (only when the
        ``suggest_followup_questions`` override is set); everything from the
        marker on is buffered and emitted as one final chunk whose context holds
        the parsed follow-up questions.
        """
        extra_info, chat_coroutine = await self.run_until_final_call(
            history, overrides, auth_claims, should_stream=True
        )
        yield {
            "choices": [
                {
                    "delta": {"role": self.ASSISTANT},
                    "context": extra_info,
                    "session_state": session_state,
                    "finish_reason": None,
                    "index": 0,
                }
            ],
            "object": "chat.completion.chunk",
        }

        followup_questions_started = False
        followup_content = ""
        async for event in await chat_coroutine:
            # "2023-07-01-preview" API version has a bug where first response has empty choices
            if event["choices"]:
                # A delta may carry an explicit None content; treat it as empty text.
                content = event["choices"][0]["delta"].get("content") or ""
                # if event contains << and not >>, it is start of follow-up question, truncate
                if overrides.get("suggest_followup_questions") and "<<" in content:
                    followup_questions_started = True
                    marker = content.index("<<")
                    earlier_content = content[:marker]
                    if earlier_content:
                        event["choices"][0]["delta"]["content"] = earlier_content
                        yield event
                    followup_content += content[marker:]
                elif followup_questions_started:
                    followup_content += content
                else:
                    yield event

        if followup_content:
            _, followup_questions = self.extract_followup_questions(followup_content)
            yield {
                "choices": [
                    {
                        "delta": {"role": self.ASSISTANT},
                        "context": {"followup_questions": followup_questions},
                        "finish_reason": None,
                        "index": 0,
                    }
                ],
                "object": "chat.completion.chunk",
            }

    async def run(
        self,
        messages: list[dict],
        stream: bool = False,
        session_state: Any = None,
        context: Optional[dict[str, Any]] = None,
    ) -> Union[dict[str, Any], AsyncGenerator[dict[str, Any], None]]:
        """Answer the conversation, either as one dict or as a chunk generator.

        Args:
            messages: Conversation history; the last entry is the new question.
            stream: When falsy a complete response dict is returned, otherwise an
                async generator of chunks.
            session_state: Opaque value echoed back in the response.
            context: Optional dict providing ``overrides`` and ``auth_claims``.

        The return shape depends only on ``stream``. It previously also depended
        on ``self.is_streaming``, a flag that an earlier GitHub-issues request
        could leave set to False on the shared instance, which made later
        ``stream=False`` calls return a generator instead of a dict.
        """
        context = context or {}
        overrides = context.get("overrides", {})
        auth_claims = context.get("auth_claims", {})
        if not stream:
            # Workaround for: https://github.com/openai/openai-python/issues/371
            async with aiohttp.ClientSession() as s:
                openai.aiosession.set(s)
                response = await self.run_without_streaming(messages, overrides, auth_claims, session_state)
            return response
        return self.run_with_streaming(messages, overrides, auth_claims, session_state)

    def get_messages_from_history(
        self,
        system_prompt: str,
        model_id: str,
        history: list[dict[str, str]],
        user_content: str,
        max_tokens: int,
        few_shots: Optional[list[dict[str, str]]] = None,
    ) -> list:
        """Build the message list for a chat completion within a token budget.

        The few-shot examples and ``user_content`` are always inserted. Earlier
        history (every entry except the last, which ``user_content`` replaces) is
        then added newest first until the next message would push the running
        token count past ``max_tokens``.

        Returns:
            The ``messages`` list held by the ``MessageBuilder``.
        """
        few_shots = few_shots or []
        message_builder = MessageBuilder(system_prompt, model_id)

        # Add examples to show the chat what responses we want. It will try to mimic any responses and make sure they match the rules laid out in the system message.
        for shot in reversed(few_shots):
            message_builder.insert_message(shot.get("role"), shot.get("content"))

        # Index just past the system prompt and the few-shot examples
        # (presumably; depends on MessageBuilder.insert_message's default index — verify).
        append_index = len(few_shots) + 1

        message_builder.insert_message(self.USER, user_content, index=append_index)
        total_token_count = message_builder.count_tokens_for_message(message_builder.messages[-1])

        for message in reversed(history[:-1]):
            potential_message_count = message_builder.count_tokens_for_message(message)
            if (total_token_count + potential_message_count) > max_tokens:
                logging.debug("Reached max tokens of %d, history will be truncated", max_tokens)
                break
            message_builder.insert_message(message["role"], message["content"], index=append_index)
            total_token_count += potential_message_count
        return message_builder.messages

    def get_search_query(self, chat_completion: dict[str, Any], user_query: str) -> Union[str, dict[str, Any]]:
        """Interpret the first completion as a search query or GitHub arguments.

        Returns:
            * the ``search_query`` argument when the model called ``search_sources``;
            * the parsed argument dict when the model called ``github_issues``;
            * the plain message content when the model answered without a function call;
            * ``user_query`` whenever none of the above yields a usable value,
              including when the function-call arguments are not valid JSON.
        """
        response_message = chat_completion["choices"][0]["message"]
        if function_call := response_message.get("function_call"):
            # The completion is capped at a small max_tokens, so the argument JSON
            # can be truncated; fall back to the raw question instead of raising.
            try:
                args = json.loads(function_call["arguments"])
            except (json.JSONDecodeError, TypeError):
                logging.warning("Could not parse function call arguments, using the original user query")
                return user_query
            if not isinstance(args, dict):
                return user_query

            if function_call["name"] == "search_sources":
                search_query = args.get("search_query", self.NO_RESPONSE)
                if search_query != self.NO_RESPONSE:
                    return search_query
            elif function_call["name"] == "github_issues":
                return args
        elif query_text := response_message.get("content"):
            if query_text.strip() != self.NO_RESPONSE:
                return query_text
        return user_query

    def extract_followup_questions(self, content: str):
        """Split ``content`` into the answer text and its ``<<...>>`` questions.

        Returns:
            ``(answer, questions)``: ``answer`` is everything before the first
            ``<<``; ``questions`` lists the text found between each ``<<``/``>>`` pair.
        """
        # A lazy match up to the closing ">>" is used because a character class
        # such as [^>>] only excludes a single ">" and would reject any question
        # that contains one.
        return content.split("<<")[0], re.findall(r"<<(.+?)>>", content, flags=re.DOTALL)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment