Created
July 4, 2024 01:57
-
-
Save jaeyow/5ceba8c569fd2bef755afa44db726a91 to your computer and use it in GitHub Desktop.
Snippets for the Candle Glow e-commerce Text2SQL
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Map schema objects onto the SQL database so they can be stored in an index.
table_node_mapping = SQLTableNodeMapping(sql_database)

# One schema object per (table name, description) pair declared in all_tables.
table_schema_objs = [
    SQLTableSchema(table_name=name, context_str=description)
    for name, description in all_tables
]

print(table_node_mapping)
print(table_schema_objs)

# Index the table schemas so tables can be retrieved by similarity.
obj_index = ObjectIndex.from_objects(
    table_schema_objs,
    table_node_mapping,
    VectorStoreIndex,
    callback_manager=callback_manager,
)

# similarity_top_k equals the number of tables, so every table is a candidate.
table_retriever = obj_index.as_retriever(similarity_top_k=len(all_tables))
query_engine = SQLTableRetrieverQueryEngine(sql_database, table_retriever)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os

from dotenv import load_dotenv
from llama_index.core.callbacks import (
    CallbackManager,
    LlamaDebugHandler,
)
from sqlalchemy import create_engine
from sqlalchemy.engine import URL

# Read environment variables (database password, database name, ...) from
# the .env file three directories above the working directory.
load_dotenv(verbose=True, dotenv_path="../../../.env")

# Debug handler wired into a callback manager that other cells pass to
# LlamaIndex components.
llama_debug = LlamaDebugHandler(print_trace_on_end=True)
callback_manager = CallbackManager([llama_debug])

# Build the connection URL from its components instead of an f-string:
# URL.create escapes special characters in the password (e.g. "@" or "/"),
# and this avoids the nested-quote, multi-line f-string that only parses on
# Python 3.12+.
engine = create_engine(
    URL.create(
        drivername="postgresql+psycopg2",
        username="postgres",
        password=os.environ["PG_VECTOR_PW"],
        host="localhost",
        port=5432,
        database=os.environ["PG_VECTOR_DB"],
    )
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from llama_index.core import SQLDatabase

# (table name, human-readable description) pairs for every table exposed to
# the Text2SQL pipeline.
all_tables = [
    (
        "categories",
        "This table contains all the categories of scented candles sold in Candle Glow.",
    ),
    (
        "products",
        "This table contains the products list sold in Candle Glow.",
    ),
    (
        "users",
        "This table contains the users list who have purchased products from Candle Glow.",
    ),
    (
        "orders",
        "This table contains the orders list placed by users in Candle Glow.",
    ),
    (
        "order_items",
        "This table contains the order items list placed by users in Candle Glow.",
    ),
    (
        "reviews",
        "This table contains the reviews given by users for the products purchased in Candle Glow.",
    ),
]

# Restrict the database wrapper to exactly the tables listed above.
sql_database = SQLDatabase(
    engine,
    include_tables=[entry[0] for entry in all_tables],
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _markdown_cell(value):
    """Render one value as text that is safe inside a markdown table cell."""
    text = str(value)
    # A raw "|" would open a new column and a line break would end the row,
    # so escape the former and flatten the latter.
    return (
        text.replace("|", "\\|")
        .replace("\r\n", " ")
        .replace("\n", " ")
        .replace("\r", " ")
    )


def format_sql_results_as_markdown_table(sql_results, headers):
    """
    Formats the SQL results as a markdown table.

    Args:
        sql_results: Iterable of rows, each indexable by column position.
            ``None`` is treated as "no rows".
        headers: Sequence of column names; one table column per header.

    Returns:
        The markdown table as a string (header row, separator row, then one
        line per result row), or "No results found." when there are no
        headers to build a table from.

    Raises:
        IndexError: If a row has fewer values than there are headers.
    """
    # Without column names there is nothing meaningful to render.
    if not headers:
        return "No results found."

    column_count = len(headers)
    header_row = "| " + " | ".join(_markdown_cell(name) for name in headers) + " |"
    separator_row = "| " + " | ".join(["---"] * column_count) + " |"

    # Index by header position so values beyond the last header are ignored
    # and a short row still fails loudly.
    table_rows = [
        "| "
        + " | ".join(_markdown_cell(row[j]) for j in range(column_count))
        + " |"
        for row in (sql_results or ())
    ]
    return "\n".join([header_row, separator_row, *table_rows])
# Markdown layout for one answered question. Placeholders filled through
# str.format: number, question, response, sql, sql_results.
response_template = """
### Question
**{number}.** **{question}**
### Answer
{response}
### Generated SQL Query
```sql
{sql}
```
### SQL Results
{sql_results}
"""
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
%%capture | |
! pip install llama-index | |
! pip install psycopg2-binary | |
! pip install SQLAlchemy | |
! pip install python-dotenv | |
! pip install llama-index-llms-bedrock | |
! pip install llama-index-embeddings-bedrock |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os

from llama_index.core import Settings
from llama_index.embeddings.bedrock import BedrockEmbedding
from llama_index.llms.bedrock import Bedrock

# Default LLM for all LlamaIndex components.
Settings.llm = Bedrock(
    model="anthropic.claude-3-sonnet-20240229-v1:0",
    aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"],
    aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"],
    # A session token only exists for temporary credentials; fall back to
    # None so long-term IAM keys do not fail with a KeyError here.
    aws_session_token=os.environ.get("AWS_SESSION_TOKEN"),
    region_name=os.environ["AWS_DEFAULT_REGION"],
)

# Default embedding model for all LlamaIndex components.
# NOTE(review): unlike the LLM above, no explicit credentials are passed here;
# presumably they are resolved from the environment - confirm.
Settings.embed_model = BedrockEmbedding(
    model_name="cohere.embed-multilingual-v3",
    region_name=os.environ["AWS_DEFAULT_REGION"],
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def text_to_sql(query_engine, question, number=1):
    """
    Calls the query engine with the given question and displays the response
    as a markdown cell.

    Args:
        query_engine: Object exposing ``query(question)``; the returned
            response must have a ``metadata`` attribute (a dict or ``None``).
        question: Natural-language question forwarded to the engine.
        number: Ordinal shown in front of the question in the rendered output.

    Prints "No results found." when the response metadata is missing or does
    not contain both the "result" and "col_keys" entries.
    """
    engine_response = query_engine.query(question)

    # metadata may be None; treat that the same as "no SQL results".
    metadata = engine_response.metadata or {}
    if "result" not in metadata or "col_keys" not in metadata:
        print("No results found.")
        return

    # NOTE(review): display and Markdown are not imported in the visible
    # snippets - presumably from IPython.display in another cell; confirm.
    display(Markdown(response_template.format(
        number=number,
        question=question,
        response=str(engine_response),
        # The generated SQL is optional; do not fail the whole render
        # if the engine did not report it.
        sql=metadata.get("sql_query", ""),
        sql_results=format_sql_results_as_markdown_table(
            metadata["result"],
            metadata["col_keys"]),
    )))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Smoke-test the database connection: fetch up to five rows from two tables,
# using a fresh connection for each query.
for probe_sql in (
    "SELECT * FROM users LIMIT 5",
    "SELECT * FROM categories LIMIT 5",
):
    with engine.connect() as conn:
        rows = conn.exec_driver_sql(probe_sql).fetchall()
        print(rows)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment