Last active
July 3, 2024 05:13
-
-
Save jaeyow/85d32c3e6f82156c7bab2617f7dcaed2 to your computer and use it in GitHub Desktop.
Text2SQL Gists
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Index the table descriptions so the query engine can retrieve the most
# relevant table schemas for each question before generating SQL.
table_node_mapping = SQLTableNodeMapping(sql_database)
# One SQLTableSchema per (name, description) pair from all_tables; the
# description is passed as context_str for the retriever to match against.
table_schema_objs = [
    SQLTableSchema(table_name=table_name, context_str=table_description)
    for table_name, table_description in all_tables
]
# Debug output: show what is being indexed.
print(table_node_mapping)
print(table_schema_objs)

obj_index = ObjectIndex.from_objects(
    table_schema_objs,
    table_node_mapping,
    VectorStoreIndex,
    callback_manager=callback_manager,
)
# Retrieve the top 2 table schemas per question (all_tables currently holds
# only two tables, so this returns both).
query_engine = SQLTableRetrieverQueryEngine(
    sql_database,
    obj_index.as_retriever(similarity_top_k=2),
    # callback_manager=callback_manager,  # Uncomment to enable debugging of the query engine
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os

from dotenv import load_dotenv
from llama_index.core.callbacks import (
    CallbackManager,
    LlamaDebugHandler,
)
from sqlalchemy import create_engine
from sqlalchemy.engine import URL

# Load settings such as PG_VECTOR_PW / PG_VECTOR_DB from a .env file; the
# relative path is resolved against the current working directory.
load_dotenv(verbose=True, dotenv_path="../../../.env")

# Print the collected LlamaIndex event trace whenever a trace ends.
llama_debug = LlamaDebugHandler(print_trace_on_end=True)
callback_manager = CallbackManager([llama_debug])

# Build the connection URL from components instead of an f-string: URL.create
# handles passwords containing '@', ':' or '/' correctly, and avoids relying
# on Python 3.12-only nested-quote f-string syntax.
engine = create_engine(
    URL.create(
        drivername="postgresql+psycopg2",
        username="postgres",
        password=os.environ["PG_VECTOR_PW"],
        host="localhost",
        port=5432,
        database=os.environ["PG_VECTOR_DB"],
    )
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from llama_index.core import SQLDatabase

# (table name, description) pairs. The names limit which tables SQLDatabase
# exposes; the descriptions become the SQLTableSchema context strings used
# for table retrieval.
all_tables = [
    ("people", "This table contains all the people or characters of all Star Wars episodes."),
    ("planet", "This table contains all the planets of all Star Wars episodes."),
]
sql_database = SQLDatabase(
    engine,
    include_tables=[table_name for table_name, _ in all_tables],
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def format_sql_results_as_markdown_table(sql_results, headers):
    """
    Format SQL result rows as a markdown table.

    Args:
        sql_results: Sequence of result rows. Each row must be indexable and
            hold at least ``len(headers)`` values; extra values are ignored.
        headers: Column names, in the same order as the values in each row.

    Returns:
        The markdown table as a string, or ``"No results found."`` when there
        are no rows or no column names to render.
    """
    if not sql_results or not headers:
        return "No results found."

    def _cell(value):
        # A raw '|' would end the cell early and a newline would end the row,
        # so escape the former and flatten line breaks into spaces.
        text = str(value).replace("|", "\\|")
        return " ".join(text.splitlines())

    column_count = len(headers)
    header_row = "| " + " | ".join(_cell(h) for h in headers) + " |"
    separator_row = "| " + " | ".join(["---"] * column_count) + " |"
    table_rows = [
        "| " + " | ".join(_cell(row[j]) for j in range(column_count)) + " |"
        for row in sql_results
    ]
    return "\n".join([header_row, separator_row, *table_rows])
# Markdown layout for one answered question, filled in via str.format() in
# text_to_sql. Placeholders: number, question, response, sql, sql_results.
response_template = """
### Question
**{number}.** **{question}**
### Answer
{response}
### Generated SQL Query
```sql
{sql}
```
### SQL Results
{sql_results}
"""
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
%%capture | |
! pip install llama-index | |
! pip install psycopg2-binary | |
! pip install SQLAlchemy | |
! pip install python-dotenv | |
! pip install llama-index-llms-bedrock | |
! pip install llama-index-embeddings-bedrock |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os

from llama_index.core import Settings
from llama_index.embeddings.bedrock import BedrockEmbedding
from llama_index.llms.bedrock import Bedrock

# Global LlamaIndex defaults: used by indexes and query engines that are not
# given an explicit llm / embed_model.
Settings.llm = Bedrock(
    model="anthropic.claude-3-sonnet-20240229-v1:0",
    aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"],
    aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"],
    # Only temporary (STS) credentials carry a session token; long-lived IAM
    # keys do not, so a missing variable must not raise KeyError.
    aws_session_token=os.environ.get("AWS_SESSION_TOKEN"),
    region_name=os.environ["AWS_DEFAULT_REGION"],
)
# No explicit keys are passed here; presumably the default AWS credential
# chain picks up the same AWS_* environment variables — confirm if needed.
Settings.embed_model = BedrockEmbedding(
    model_name="cohere.embed-multilingual-v3",
    region_name=os.environ["AWS_DEFAULT_REGION"],
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def text_to_sql(query_engine, question, number=1):
    """
    Ask the query engine a question and display the answer as a Markdown cell.

    Fills ``response_template`` with the natural-language answer, the generated
    SQL and the SQL result rows (rendered as a markdown table). If the response
    metadata lacks the result rows or column names, prints
    "No results found." instead.

    Args:
        query_engine: Engine whose ``query(question)`` returns a response whose
            ``metadata`` holds "result", "col_keys" and "sql_query".
        question: Natural-language question to ask.
        number: Ordinal shown next to the question in the rendered output.
    """
    from IPython.display import Markdown, display

    engine_response = query_engine.query(question)
    # Response metadata is optional; treat a missing dict as "no results"
    # rather than failing with a TypeError on the membership test.
    metadata = engine_response.metadata or {}
    if "result" not in metadata or "col_keys" not in metadata:
        print("No results found.")
        return

    display(Markdown(response_template.format(
        number=number,
        question=question,
        response=str(engine_response),
        sql=metadata.get("sql_query", ""),
        sql_results=format_sql_results_as_markdown_table(
            metadata["result"], metadata["col_keys"]
        ),
    )))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Smoke-test the database connection by printing the first 5 rows of each
# table used in this notebook. Table names are fixed literals, not user input.
with engine.connect() as connection:
    for table in ("people", "planet"):
        result = connection.exec_driver_sql(f"SELECT * FROM {table} LIMIT 5")
        print(result.fetchall())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment