Skip to content

Instantly share code, notes, and snippets.

@phantompunk
Last active February 14, 2024 20:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save phantompunk/5300dbac22215c4630dca8d6fdbb8478 to your computer and use it in GitHub Desktop.
Save phantompunk/5300dbac22215c4630dca8d6fdbb8478 to your computer and use it in GitHub Desktop.
Snippets related to my 'Tapping into GenAI' IPPON blog post
import os
from dotenv import load_dotenv
from langchain_community.utilities import SQLDatabase
load_dotenv()
def get_snowflake_uri() -> str:
    """Build the Snowflake SQLAlchemy connection URI from environment variables.

    Reads USERNAME, PASSWORD, ACCOUNT, DATABASE, SCHEMA, WAREHOUSE and ROLE
    from ``os.environ``. Every value except the account (the host part) is
    percent-encoded so that reserved characters such as ``@``, ``:``, ``/``,
    ``?`` or ``&`` -- most commonly found in passwords -- cannot corrupt the
    structure of the URI.

    Returns:
        A string of the form
        ``snowflake://user:password@account/database/schema?warehouse=W&role=R``.

    Raises:
        KeyError: If any of the required environment variables is not set.
    """
    from urllib.parse import quote

    env = os.environ

    def encoded(name: str) -> str:
        # safe="" also escapes "/", which quote() leaves untouched by default.
        return quote(env[name], safe="")

    return "snowflake://{}:{}@{}/{}/{}?warehouse={}&role={}".format(
        encoded("USERNAME"),
        encoded("PASSWORD"),
        env["ACCOUNT"],
        encoded("DATABASE"),
        encoded("SCHEMA"),
        encoded("WAREHOUSE"),
        encoded("ROLE"),
    )
def main():
    """Connect to Snowflake and print its dialect, usable tables and table info."""
    database = SQLDatabase.from_uri(get_snowflake_uri())
    dialect = database.dialect
    table_names = database.get_usable_table_names()
    table_info = database.get_table_info()
    print(f"Dialect {dialect}, Table Names {table_names}, Table Info: {table_info}")
if __name__ == "__main__":
main()
-- Dialect snowflake, Table Names ['claims_raw', 'cpt_codes', 'icd_codes'], Table Info:
CREATE TABLE claims_raw (
value VARIANT
)
/*
3 rows from claims_raw table:
value
{
"allowed_amount": 712,
"billed_amount": 854,
"claim_number": "2368b1d0-7d7b-47c7-99b7-655675
{
"allowed_amount": 5008,
"billed_amount": 6693,
"claim_number": "819b31c9-c690-4b3b-bb9c-0287
{
"allowed_amount": 2834,
"billed_amount": 3040,
"claim_number": "baa64aa5-6ccd-4dab-b2d9-f7bd
*/
CREATE TABLE cpt_codes (
code VARCHAR(16777216),
long_description VARCHAR(16777216),
short_description VARCHAR(16777216)
)
/*
3 rows from cpt_codes table:
code long_description short_description
0001U RED BLOOD CELL ANTIGEN TYPING, DNA, HUMAN ERYTHROCYTE ANTIGEN GENE ANALYSIS OF 35 ANTIGENS FROM 11 B Rbc dna hea 35 ag 11 bld grp
0002M LIVER DISEASE, TEN BIOCHEMICAL ASSAYS (ALT, A2-MACROGLOBULIN, APOLIPOPROTEIN A-1, TOTAL BILIRUBIN, G Liver dis 10 assays w/ash
0002U ONCOLOGY (COLORECTAL), QUANTITATIVE ASSESSMENT OF THREE URINE METABOLITES (ASCORBIC ACID, SUCCINIC A Onc clrct 3 ur metab alg plp
*/
CREATE TABLE icd_codes (
code VARCHAR(16777216),
short_description VARCHAR(16777216),
is_covered BOOLEAN
)
/*
3 rows from icd_codes table:
code short_description is_covered
A00.0 Cholera due to Vibrio cholerae 01, biovar cholerae True
A00.1 Cholera due to Vibrio cholerae 01, biovar eltor True
A00.9 Cholera, unspecified True
*/
import os
from boto3.session import boto3
from botocore.config import Config
from dotenv import load_dotenv
from langchain.chat_models.bedrock import BedrockChat
from langchain_community.utilities import SQLDatabase
load_dotenv()
def get_snowflake_uri() -> str:
    """Build the Snowflake SQLAlchemy connection URI from environment variables.

    Reads USERNAME, PASSWORD, ACCOUNT, DATABASE, SCHEMA, WAREHOUSE and ROLE
    from ``os.environ``. Every value except the account (the host part) is
    percent-encoded so that reserved characters such as ``@``, ``:``, ``/``,
    ``?`` or ``&`` -- most commonly found in passwords -- cannot corrupt the
    structure of the URI.

    Returns:
        A string of the form
        ``snowflake://user:password@account/database/schema?warehouse=W&role=R``.

    Raises:
        KeyError: If any of the required environment variables is not set.
    """
    from urllib.parse import quote

    env = os.environ

    def encoded(name: str) -> str:
        # safe="" also escapes "/", which quote() leaves untouched by default.
        return quote(env[name], safe="")

    return "snowflake://{}:{}@{}/{}/{}?warehouse={}&role={}".format(
        encoded("USERNAME"),
        encoded("PASSWORD"),
        env["ACCOUNT"],
        encoded("DATABASE"),
        encoded("SCHEMA"),
        encoded("WAREHOUSE"),
        encoded("ROLE"),
    )
def get_bedrock_client():
    """Create a boto3 ``bedrock-runtime`` client configured from the environment.

    The region is taken from ``AWS_REGION`` (falling back to
    ``AWS_DEFAULT_REGION``) and the credentials profile from ``AWS_PROFILE``;
    unset variables are passed through as ``None``. The client uses the
    "standard" retry mode with ``max_attempts`` set to 10.
    """
    env = os.environ
    region = env.get("AWS_REGION", env.get("AWS_DEFAULT_REGION"))
    retry_config = Config(
        region_name=region,
        retries={"max_attempts": 10, "mode": "standard"},
    )
    session = boto3.Session(region_name=region, profile_name=env.get("AWS_PROFILE"))
    return session.client(service_name="bedrock-runtime", config=retry_config)
def main():
    """Open the Snowflake database and instantiate the Bedrock chat model."""
    snowflake_uri = get_snowflake_uri()
    db = SQLDatabase.from_uri(snowflake_uri)
    bedrock_client = get_bedrock_client()
    llm = BedrockChat(
        model_id="anthropic.claude-instant-v1",
        model_kwargs={"temperature": 0},
        client=bedrock_client,
    )
if __name__ == "__main__":
main()
def get_bedrock_client():
    """Return a ``bedrock-runtime`` client built from AWS environment settings.

    Uses ``AWS_REGION`` (or ``AWS_DEFAULT_REGION`` when the former is unset)
    for the region and ``AWS_PROFILE`` for the session profile. Retries are
    configured in "standard" mode with ``max_attempts`` of 10.
    """
    aws_region = os.environ.get("AWS_REGION", os.environ.get("AWS_DEFAULT_REGION"))
    aws_profile = os.environ.get("AWS_PROFILE")
    retries = {"max_attempts": 10, "mode": "standard"}
    client_config = Config(region_name=aws_region, retries=retries)
    aws_session = boto3.Session(region_name=aws_region, profile_name=aws_profile)
    return aws_session.client(service_name="bedrock-runtime", config=client_config)
# Chat model: Claude Instant served through Bedrock, temperature 0.
llm = BedrockChat(
    client=get_bedrock_client(),
    model_id="anthropic.claude-instant-v1",
    model_kwargs=dict(temperature=0),
)

# Ask the model to write a SQL query for the question and show what it returns.
chain = create_sql_query_chain(llm, db)
response = chain.invoke(dict(question="How many claims are there?"))
print(response)
# Here is the syntactically correct SQL query to answer the question "How many claims are there?":
#
# Question: How many claims are there?
# SQLQuery: SELECT COUNT(*) FROM claims_raw;
# Build a chain that turns a natural-language question into SQL for `db`,
# run it for one question and print the raw model response.
chain = create_sql_query_chain(llm, db)
response = chain.invoke({"question": "How many claims are there?"})
print(response)
# Here is the syntactically correct SQL query to answer the question "How many claims are there?":
#
# Question: How many claims are there?
# SQLQuery: SELECT COUNT(*) FROM claims_raw;
import os
from boto3.session import boto3
from botocore.config import Config
from dotenv import load_dotenv
from langchain.chains import create_sql_query_chain
from langchain.chat_models.bedrock import BedrockChat
from langchain_community.utilities import SQLDatabase
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
load_dotenv()
def get_snowflake_uri() -> str:
    """Build the Snowflake SQLAlchemy connection URI from environment variables.

    Reads USERNAME, PASSWORD, ACCOUNT, DATABASE, SCHEMA, WAREHOUSE and ROLE
    from ``os.environ``. Every value except the account (the host part) is
    percent-encoded so that reserved characters such as ``@``, ``:``, ``/``,
    ``?`` or ``&`` -- most commonly found in passwords -- cannot corrupt the
    structure of the URI.

    Returns:
        A string of the form
        ``snowflake://user:password@account/database/schema?warehouse=W&role=R``.

    Raises:
        KeyError: If any of the required environment variables is not set.
    """
    from urllib.parse import quote

    env = os.environ

    def encoded(name: str) -> str:
        # safe="" also escapes "/", which quote() leaves untouched by default.
        return quote(env[name], safe="")

    return "snowflake://{}:{}@{}/{}/{}?warehouse={}&role={}".format(
        encoded("USERNAME"),
        encoded("PASSWORD"),
        env["ACCOUNT"],
        encoded("DATABASE"),
        encoded("SCHEMA"),
        encoded("WAREHOUSE"),
        encoded("ROLE"),
    )
def get_bedrock_client():
    """Construct the boto3 client used to call Bedrock's runtime API.

    Region: ``AWS_REGION``, else ``AWS_DEFAULT_REGION``. Profile:
    ``AWS_PROFILE``. Both may be absent, in which case ``None`` is handed to
    boto3. The client retries in "standard" mode with ``max_attempts`` of 10.
    """
    region = os.environ.get("AWS_REGION", os.environ.get("AWS_DEFAULT_REGION"))
    session_options = dict(
        region_name=region,
        profile_name=os.environ.get("AWS_PROFILE"),
    )
    config_options = dict(
        region_name=region,
        retries=dict(max_attempts=10, mode="standard"),
    )
    retry_config = Config(**config_options)
    return boto3.Session(**session_options).client(
        service_name="bedrock-runtime",
        config=retry_config,
    )
def main():
    """Generate SQL for a question with the LLM, then run it against Snowflake.

    The query-writing chain returns free text in which the statement follows
    a ``SQLQuery:`` label, for example::

        Question: How many claims are there?
        SQLQuery: SELECT COUNT(*) FROM claims_raw;

    so the SQL is separated from the surrounding text before it is handed to
    the query tool. Both the raw model response and the query result are
    printed.
    """
    db = SQLDatabase.from_uri(get_snowflake_uri())
    # llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
    llm = BedrockChat(
        model_id="anthropic.claude-instant-v1",
        model_kwargs={"temperature": 0},
        client=get_bedrock_client(),
    )
    execute_query = QuerySQLDataBaseTool(db=db)
    write_query = create_sql_query_chain(llm, db)

    response = write_query.invoke({"question": "How many claims are there?"})
    print(response)

    # Keep everything after the last "SQLQuery:" label. Splitting on a bare
    # ":" and taking the last line would truncate any statement that itself
    # contains a colon (e.g. Snowflake `value:claim_number` paths or `::`
    # casts) and would drop all but the final line of a multi-line query.
    query = response.rsplit("SQLQuery:", 1)[-1].strip()
    result = execute_query.invoke(query)
    print(result)
if __name__ == "__main__":
main()
# Pipe the query-writing chain directly into the query-execution tool.
# NOTE: the tool receives the model's whole response, not just the SQL; the
# error output that follows this snippet shows Snowflake rejecting the
# leading "Here is the syntactically correct SQL query..." text.
execute_query = QuerySQLDataBaseTool(db=db)
write_query = create_sql_query_chain(llm, db)
chain = write_query | execute_query
response = chain.invoke({"question": "How many claims are there?"})
print(response)
Error: (snowflake.connector.errors.ProgrammingError) 001003 (42000): SQL compilation error:
syntax error line 1 at position 0 unexpected 'Here'.
[SQL: Here is the syntactically correct SQL query to answer the question "How many claims are there?":
Question: How many claims are there?
SQLQuery: SELECT COUNT(*) FROM claims_raw;]
(Background on this error at: https://sqlalche.me/e/14/f405)
execute_query = QuerySQLDataBaseTool(db=db)
write_query = create_sql_query_chain(llm, db)

# Step 1: let the model write the query (not piped into the tool, because the
# response contains explanatory text around the SQL).
chain = write_query  # | execute_query
response = chain.invoke({"question": "How many claims are there?"})
print(response)

# Step 2: run only the SQL. Everything after the last "SQLQuery:" label is
# kept; splitting on a bare ":" would truncate statements that contain colons
# themselves (e.g. Snowflake `value:claim_number` paths or `::` casts).
chain = execute_query
response = chain.invoke(response.rsplit("SQLQuery:", 1)[-1].strip())
print(response)
# [(1000,)]
# Alternative chat model: OpenAI gpt-3.5-turbo with temperature 0.
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
# [(1000,)]
# pretty_print() writes the prompt to stdout itself and returns None, so it is
# called directly; wrapping it in print() would emit an extra "None" line.
chain.get_prompts()[0].pretty_print()
# Given an input question, first create a syntactically...
# Unless the user specifies...
# Never query for all the columns...
# Pay attention to use only the column ...
# Use the following format:
#
# Question: Question here
# SQLQuery:; Query to run
# SQLResult: Result of the SQLQuery
# Answer: Final answer here
#
# Only use the following tables:
# {table_info}
#
# {input}
# Custom prompt template for the SQL-writing chain. Placeholders filled in at
# run time: {dialect}, {top_k}, {table_info} and {input}. Compared with the
# default prompt shown above it tells the model not to print the 'Question:' /
# 'SQLQuery:' section headers and prefixes those two labels with "--"
# (presumably so any header the model still emits reads as a SQL comment --
# TODO confirm intent).
snowflake_prompt = """You are a Snowflake expert. Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results. You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for a the few relevant columns given the question.
Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Do not print section headers such as 'Question:' or 'SQLQuery:'. Just print inputs as they are.
Use the following format:
--Question: Question here
--SQLQuery:; Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here
Only use the following tables:
{table_info}
{input}
"""
# Wrap the text in a PromptTemplate and hand it to the query-writing chain in
# place of the default prompt.
prompt = PromptTemplate.from_template(snowflake_prompt)
write_query = create_sql_query_chain(llm, db, prompt)
# [(1000,)]
import os
from boto3.session import boto3
from botocore.config import Config
from dotenv import load_dotenv
from langchain.chat_models.bedrock import BedrockChat
from langchain.chains import create_sql_query_chain
from langchain_community.utilities import SQLDatabase
from operator import itemgetter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain_core.prompts import PromptTemplate
load_dotenv()
# Template for the final answering step. Placeholders: {question} (the user's
# question), {query} (the generated SQL) and {result} (the SQL result).
answer_prompt = """Given the following commented out user question, corresponding SQL query, and SQL result, answer the user question.
Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer:
"""
# Prompt template for the SQL-writing step. Placeholders filled in at run
# time: {dialect}, {top_k}, {table_info} and {input}. It instructs the model
# not to print the 'Question:' / 'SQLQuery:' section headers and shows those
# two labels prefixed with "--" (presumably so a stray header reads as a SQL
# comment -- TODO confirm intent).
snowflake_prompt = """You are a Snowflake expert. Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results. You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for a the few relevant columns given the question.
Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Do not print section headers such as 'Question:' or 'SQLQuery:'. Just print inputs as they are.
Use the following format:
--Question: Question here
--SQLQuery:; Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here
Only use the following tables:
{table_info}
{input}
"""
def get_snowflake_uri() -> str:
    """Build the Snowflake SQLAlchemy connection URI from environment variables.

    Reads USERNAME, PASSWORD, ACCOUNT, DATABASE, SCHEMA, WAREHOUSE and ROLE
    from ``os.environ``. Every value except the account (the host part) is
    percent-encoded so that reserved characters such as ``@``, ``:``, ``/``,
    ``?`` or ``&`` -- most commonly found in passwords -- cannot corrupt the
    structure of the URI.

    Returns:
        A string of the form
        ``snowflake://user:password@account/database/schema?warehouse=W&role=R``.

    Raises:
        KeyError: If any of the required environment variables is not set.
    """
    from urllib.parse import quote

    env = os.environ

    def encoded(name: str) -> str:
        # safe="" also escapes "/", which quote() leaves untouched by default.
        return quote(env[name], safe="")

    return "snowflake://{}:{}@{}/{}/{}?warehouse={}&role={}".format(
        encoded("USERNAME"),
        encoded("PASSWORD"),
        env["ACCOUNT"],
        encoded("DATABASE"),
        encoded("SCHEMA"),
        encoded("WAREHOUSE"),
        encoded("ROLE"),
    )
def get_bedrock_client():
    """Build a ``bedrock-runtime`` boto3 client from AWS environment variables.

    ``AWS_REGION`` is preferred for the region, with ``AWS_DEFAULT_REGION`` as
    the fallback; ``AWS_PROFILE`` selects the session profile. The client is
    configured for "standard" retries with ``max_attempts`` of 10.
    """
    if "AWS_REGION" in os.environ:
        region = os.environ["AWS_REGION"]
    else:
        region = os.environ.get("AWS_DEFAULT_REGION")

    retry_config = Config(
        region_name=region,
        retries={"max_attempts": 10, "mode": "standard"},
    )
    session = boto3.Session(
        region_name=region,
        profile_name=os.environ.get("AWS_PROFILE"),
    )
    return session.client(service_name="bedrock-runtime", config=retry_config)
def main():
    """Answer a natural-language question about the Snowflake data end to end.

    Pipeline: the LLM writes a SQL query for the question (using
    ``snowflake_prompt``), the query is executed against the database, and the
    LLM is then asked (using ``answer_prompt``) to phrase a final answer from
    the question, the query and its result. The answer is printed.
    """
    db = SQLDatabase.from_uri(get_snowflake_uri())
    llm = BedrockChat(
        model_id="anthropic.claude-instant-v1",
        model_kwargs={"temperature": 0},
        client=get_bedrock_client(),
    )

    write_query = create_sql_query_chain(
        llm, db, PromptTemplate.from_template(snowflake_prompt)
    )
    execute_query = QuerySQLDataBaseTool(db=db)
    answer = PromptTemplate.from_template(answer_prompt) | llm | StrOutputParser()

    # Each .assign() adds a key to the input dict: first "query" (generated
    # SQL), then "result" (output of running that SQL); the answer step
    # receives "question", "query" and "result" together.
    chain = (
        RunnablePassthrough.assign(query=write_query).assign(
            result=itemgetter("query") | execute_query
        )
        | answer
    )
    response = chain.invoke({"question": "How many claims are there?"})
    print(response)
if __name__ == "__main__":
main()
# Final step: format the answer prompt, send it to the LLM, return plain text.
answer = PromptTemplate.from_template(answer_prompt) | llm | StrOutputParser()

# Full pipeline: add "query" (generated SQL) and "result" (its execution
# output) to the input dict, then produce the final answer.
with_query = RunnablePassthrough.assign(query=write_query)
chain = with_query.assign(result=itemgetter("query") | execute_query) | answer
# There are 1000 claims.
import os
from boto3.session import Config, boto3
from dotenv import load_dotenv
from langchain_community.agent_toolkits import SQLDatabaseToolkit, create_sql_agent
from langchain_openai import ChatOpenAI
from langchain.chat_models.bedrock import BedrockChat
from langchain_community.utilities import SQLDatabase
load_dotenv()
def get_snowflake_uri() -> str:
    """Build the Snowflake SQLAlchemy connection URI from environment variables.

    Reads USERNAME, PASSWORD, ACCOUNT, DATABASE, SCHEMA, WAREHOUSE and ROLE
    from ``os.environ``. Every value except the account (the host part) is
    percent-encoded so that reserved characters such as ``@``, ``:``, ``/``,
    ``?`` or ``&`` -- most commonly found in passwords -- cannot corrupt the
    structure of the URI.

    Returns:
        A string of the form
        ``snowflake://user:password@account/database/schema?warehouse=W&role=R``.

    Raises:
        KeyError: If any of the required environment variables is not set.
    """
    from urllib.parse import quote

    env = os.environ

    def encoded(name: str) -> str:
        # safe="" also escapes "/", which quote() leaves untouched by default.
        return quote(env[name], safe="")

    return "snowflake://{}:{}@{}/{}/{}?warehouse={}&role={}".format(
        encoded("USERNAME"),
        encoded("PASSWORD"),
        env["ACCOUNT"],
        encoded("DATABASE"),
        encoded("SCHEMA"),
        encoded("WAREHOUSE"),
        encoded("ROLE"),
    )
def get_bedrock_client():
    """Return a boto3 client for the ``bedrock-runtime`` service.

    The session region is read from ``AWS_REGION`` (then
    ``AWS_DEFAULT_REGION``) and the profile from ``AWS_PROFILE``. Retries use
    "standard" mode with ``max_attempts`` of 10.
    """
    lookup = os.environ.get
    region = lookup("AWS_REGION", lookup("AWS_DEFAULT_REGION"))
    profile = lookup("AWS_PROFILE")
    retry_policy = {"max_attempts": 10, "mode": "standard"}
    retry_config = Config(region_name=region, retries=retry_policy)
    return boto3.Session(region_name=region, profile_name=profile).client(
        service_name="bedrock-runtime",
        config=retry_config,
    )
def main():
    """Answer a question about the Snowflake data using a LangChain SQL agent.

    Builds the database wrapper and the Bedrock chat model, wraps both in a
    ``SQLDatabaseToolkit``, creates an agent limited to 25 iterations and
    prints the agent's response to one question.
    """
    db = SQLDatabase.from_uri(get_snowflake_uri())
    # llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
    llm = BedrockChat(
        model_id="anthropic.claude-v2",
        model_kwargs={"temperature": 0},
        client=get_bedrock_client(),
    )
    toolkit = SQLDatabaseToolkit(db=db, llm=llm)
    # The toolkit already wraps `db`. create_sql_agent expects exactly one of
    # `db` or `toolkit`, so only the toolkit is passed.
    agent_executor = create_sql_agent(
        llm,
        toolkit=toolkit,
        max_iterations=25,
        verbose=True,
    )
    response = agent_executor.invoke({"input": "How many claims are there?"})
    print(response)
if __name__ == "__main__":
main()
# Bundle the database and model into a toolkit and build the SQL agent from it.
# The toolkit already wraps `db`. create_sql_agent expects exactly one of
# `db` or `toolkit`, so only the toolkit is passed.
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
agent_executor = create_sql_agent(llm, toolkit=toolkit, verbose=True)
response = agent_executor.invoke({"input": "How many claims are there?"})
print(response)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment