Skip to content

Instantly share code, notes, and snippets.

@danyaljj
Created July 19, 2024 21:43
Show Gist options
  • Save danyaljj/df781ecfca5e99965779e25d989788af to your computer and use it in GitHub Desktop.
Save danyaljj/df781ecfca5e99965779e25d989788af to your computer and use it in GitHub Desktop.
sql-agent.py
def get_response_sql(
    user_query,
    chat_history,
    plot=False,
    *,
    db_path="metadataDB/output_database.db",
):
    """Answer ``user_query`` with a LangChain SQL agent over a SQLite database.

    The agent runs with ``verbose=True``. Everything it prints to stdout is
    captured, echoed to the console and passed to ``log_to_workflow``. When
    the agent produces a non-empty answer, it is appended to
    ``st.session_state.chat_history`` as an ``AIMessage`` and a Streamlit
    rerun is requested. The answer is prefixed with ``"PLOT: "`` when
    ``plot`` is true.

    Args:
        user_query: The user's natural-language question.
        chat_history: Prior conversation, interpolated as text into the
            agent's input prompt.
        plot: If true, tag the stored answer with ``"PLOT: "`` and skip
            ``refresh_suggestions``.
        db_path: Path to the SQLite database file. Keyword-only; defaults to
            the location that was previously hard-coded.

    Returns:
        ``1`` after the answer has been stored. Under Streamlit the rerun call
        halts the script run, so this return may not be reached there.
        Returns ``"No response generated by the agent."`` when the agent
        yields no output.

    Raises:
        Whatever ``agent_executor.invoke`` raises. Stdout is restored and the
        partial trace is logged before the exception propagates.
    """
    from contextlib import redirect_stdout

    # NOTE(review): the agent executes LLM-generated SQL on a read-write
    # connection. Consider a read-only URI
    # (e.g. "sqlite:///file:<path>?mode=ro&uri=true") if writes are never
    # intended.
    db = SQLDatabase.from_uri(f"sqlite:///{db_path}")

    underspecified = classify_underspecified_query(user_query, chat_history)
    log_to_workflow(f"Query is {underspecified}")

    llm = ChatOpenAI(model="gpt-4o", max_tokens=4096, temperature=0)
    agent_executor = create_sql_agent(
        llm, db=db, agent_type="openai-tools", verbose=True
    )

    # Capture the agent's verbose trace. redirect_stdout restores the
    # *previous* stream even if invoke() raises; the old manual swap to
    # sys.__stdout__ could leave stdout hijacked or clobber an outer
    # redirection.
    sql_workflow = io.StringIO()
    try:
        with redirect_stdout(sql_workflow):
            response = agent_executor.invoke(
                {
                    "input": f"user query :{user_query} ; chat history for chatbot:{chat_history}"
                }
            )
    finally:
        # Log the trace on both success and failure so errors stay debuggable.
        trace = sql_workflow.getvalue()
        print(trace)
        log_to_workflow(f"Function : get_response_sql \n Output:{trace}")

    # A missing "output" key is treated the same as an empty answer.
    final_response = response.get("output")
    if not final_response:
        return "No response generated by the agent."

    print("%" * 50)
    print(final_response)
    print("%" * 50)

    history = st.session_state.chat_history
    if plot:
        history.append(AIMessage(content="PLOT: " + final_response))
    else:
        history.append(AIMessage(content=final_response))
        refresh_suggestions(history)

    # ``st.experimental_rerun`` was deprecated in favour of ``st.rerun``
    # (Streamlit 1.27) and later removed. Use whichever this install provides.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()
    return 1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment