Skip to content

Instantly share code, notes, and snippets.

@jaeyow
Created July 4, 2024 01:57
Show Gist options
  • Save jaeyow/5ceba8c569fd2bef755afa44db726a91 to your computer and use it in GitHub Desktop.
Save jaeyow/5ceba8c569fd2bef755afa44db726a91 to your computer and use it in GitHub Desktop.
Snippets for the Candle Glow e-commerce Text2SQL
# Wrap every (table name, description) entry of all_tables in a SQLTableSchema
# and index those schema objects through the table/node mapping.
table_node_mapping = SQLTableNodeMapping(sql_database)
table_schema_objs = [
    SQLTableSchema(table_name=name, context_str=description)
    for name, description in all_tables
]
print(table_node_mapping)
print(table_schema_objs)

obj_index = ObjectIndex.from_objects(
    table_schema_objs,
    table_node_mapping,
    VectorStoreIndex,
    callback_manager=callback_manager,
)

# similarity_top_k is set to the number of entries in all_tables --
# presumably so that every described table can be retrieved for a question.
table_retriever = obj_index.as_retriever(similarity_top_k=len(all_tables))
query_engine = SQLTableRetrieverQueryEngine(sql_database, table_retriever)
import os

from dotenv import load_dotenv
from llama_index.core.callbacks import (
    CallbackManager,
    LlamaDebugHandler,
)
from sqlalchemy import create_engine
from sqlalchemy.engine import URL

# Load environment variables from the .env file; PG_VECTOR_PW and
# PG_VECTOR_DB used below are presumably defined there.
load_dotenv(verbose=True, dotenv_path="../../../.env")

# Debug handler that prints the trace when a traced operation ends.
llama_debug = LlamaDebugHandler(print_trace_on_end=True)
callback_manager = CallbackManager([llama_debug])

# Build the connection URL from its components instead of interpolating the
# password into a string: URL.create escapes special characters (e.g. "@",
# "/", ":") in the password, and this form avoids the multi-line,
# nested-quote f-string that only parses on Python 3.12+.
engine = create_engine(
    URL.create(
        drivername="postgresql+psycopg2",
        username="postgres",
        password=os.environ["PG_VECTOR_PW"],
        host="localhost",
        port=5432,
        database=os.environ["PG_VECTOR_DB"],
    )
)
from llama_index.core import SQLDatabase
# (table name, natural-language description) pairs for the tables exposed to
# the SQL database wrapper below.
all_tables = [
    (
        "categories",
        "This table contains all the categories of scented candles sold in Candle Glow.",
    ),
    (
        "products",
        "This table contains the products list sold in Candle Glow.",
    ),
    (
        "users",
        "This table contains the users list who have purchased products from Candle Glow.",
    ),
    (
        "orders",
        "This table contains the orders list placed by users in Candle Glow.",
    ),
    (
        "order_items",
        "This table contains the order items list placed by users in Candle Glow.",
    ),
    (
        "reviews",
        "This table contains the reviews given by users for the products purchased in Candle Glow.",
    ),
]

# Only the tables named in all_tables are included.
sql_database = SQLDatabase(
    engine,
    include_tables=[entry[0] for entry in all_tables],
)
def format_sql_results_as_markdown_table(sql_results, headers):
    """
    Format SQL result rows as a markdown table.

    Args:
        sql_results: Iterable of rows; each row is indexable by column
            position (e.g. a tuple). ``None`` is treated as "no rows".
        headers: Sequence of column names, one per column to render.

    Returns:
        The markdown table as a single string (header row, separator row,
        then one line per result row), or ``"No results found."`` when both
        ``sql_results`` and ``headers`` are empty.

    Raises:
        IndexError: If a row has fewer values than there are headers.
    """
    if not sql_results and not headers:
        return "No results found."

    def _cell(value):
        # A raw "|" would open a new column and a raw line break would end
        # the table row, so both are neutralised before rendering.
        text = str(value).replace("|", "\\|")
        return (
            text.replace("\r\n", "<br>")
            .replace("\n", "<br>")
            .replace("\r", "<br>")
        )

    column_count = len(headers)
    header_row = "| " + " | ".join(_cell(name) for name in headers) + " |"
    separator_row = "| " + " | ".join(["---"] * column_count) + " |"
    table_rows = [
        "|" + "".join(f" {_cell(row[j])} |" for j in range(column_count))
        for row in (sql_results or [])
    ]
    return "\n".join([header_row, separator_row, *table_rows])
# Markdown layout for one question/answer report. The placeholders
# {number}, {question}, {response}, {sql} and {sql_results} are filled in
# with str.format. The leading and trailing newlines are part of the template.
response_template = (
    "\n"
    "### Question\n"
    "**{number}.** **{question}**\n"
    "### Answer\n"
    "{response}\n"
    "### Generated SQL Query\n"
    "```sql\n"
    "{sql}\n"
    "```\n"
    "### SQL Results\n"
    "{sql_results}\n"
)
%%capture
! pip install llama-index
! pip install psycopg2-binary
! pip install SQLAlchemy
! pip install python-dotenv
! pip install llama-index-llms-bedrock
! pip install llama-index-embeddings-bedrock
from llama_index.core import Settings
from llama_index.llms.bedrock import Bedrock
from llama_index.embeddings.bedrock import BedrockEmbedding
# Use Bedrock-hosted models as the global LLM and embedding model.
# Credentials and region are read from environment variables.
Settings.llm = Bedrock(
    model="anthropic.claude-3-sonnet-20240229-v1:0",
    aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"],
    aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"],
    # Only temporary (STS) credentials come with a session token; use .get so
    # long-lived access keys do not fail with a KeyError here.
    aws_session_token=os.environ.get("AWS_SESSION_TOKEN"),
    region_name=os.environ["AWS_DEFAULT_REGION"],
)
Settings.embed_model = BedrockEmbedding(
    model_name="cohere.embed-multilingual-v3",
    region_name=os.environ["AWS_DEFAULT_REGION"],
)
def text_to_sql(query_engine, question, number=1):
    """
    Run a question through the query engine and display the outcome as markdown.

    Args:
        query_engine: Object exposing ``query(question)``; the returned
            response must provide a ``metadata`` attribute.
        question: Natural-language question passed to the engine.
        number: Ordinal shown in front of the question in the report.

    Returns:
        None. When the response metadata holds both ``"result"`` and
        ``"col_keys"`` the filled-in ``response_template`` is displayed as
        markdown; otherwise ``"No results found."`` is printed.
    """
    engine_response = query_engine.query(question)

    # Treat missing (None) metadata like empty metadata instead of failing
    # on the membership tests below.
    metadata = engine_response.metadata or {}
    if "result" not in metadata or "col_keys" not in metadata:
        print("No results found.")
        return

    results_table = format_sql_results_as_markdown_table(
        metadata["result"],
        metadata["col_keys"],
    )
    report = response_template.format(
        number=number,
        question=question,
        response=str(engine_response),
        # "sql_query" is not covered by the guard above, so fall back to an
        # empty SQL block rather than raising KeyError.
        sql=metadata.get("sql_query", ""),
        sql_results=results_table,
    )
    display(Markdown(report))
# Smoke-test the database connection: run each sample query on its own
# connection and print the fetched rows.
for statement in (
    "SELECT * FROM users LIMIT 5",
    "SELECT * FROM categories LIMIT 5",
):
    with engine.connect() as connection:
        rows = connection.exec_driver_sql(statement).fetchall()
        print(rows)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment