Created
July 28, 2023 14:14
-
-
Save RageshAntony/5666f139ec69f5a8f82d076a54fb5efe to your computer and use it in GitHub Desktop.
Query SQL databases with human language
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json
import os

import openai
import pymysql
from colorama import Fore
from tabulate import tabulate
def dictfetchall(cursor, fet_rows):
    """Map already-fetched rows to dicts keyed by the cursor's column names.

    Args:
        cursor: A DB-API cursor whose ``description`` describes ``fet_rows``;
            the first element of each ``description`` entry is the column name.
        fet_rows: Rows previously fetched from ``cursor`` (e.g. the result of
            ``cursor.fetchall()``), each a sequence of column values.

    Returns:
        A list with one dict per row (column name -> value). Returns an empty
        list when ``fet_rows`` is empty, without reading ``cursor.description``
        (which is None for statements that produce no result set).
    """
    rows = list(fet_rows)
    if not rows:
        return []
    # Column names are the same for every row: compute them once.
    columns = [col[0] for col in cursor.description]
    return [dict(zip(columns, row)) for row in rows]
def run_query(query: str) -> str:
    """Execute ``query`` on the MySQL database, print a table, and return JSON.

    Connection settings are read from the ``DB_HOST``, ``DB_USER``,
    ``DB_PASSWORD`` and ``DB_NAME`` environment variables, falling back to
    the script's original development defaults.

    Args:
        query: SQL text to execute. In this script it is produced by the
            language model and executed verbatim (see the note below).

    Returns:
        A JSON array with one object per result row (column name -> value;
        values json cannot encode natively, such as dates, are converted with
        ``str``). On a ``pymysql.Error`` the error is printed and a JSON
        object ``{"error": "<message>"}`` is returned, so the caller always
        receives a string.
    """
    # Environment overrides the original hard-coded development values so
    # credentials do not have to live in source code.
    db_host = os.environ.get("DB_HOST", "localhost")
    db_user = os.environ.get("DB_USER", "ragesh")
    db_password = os.environ.get("DB_PASSWORD", "ragesh")
    db_name = os.environ.get("DB_NAME", "ottflix")

    try:
        connection = pymysql.connect(
            host=db_host,
            user=db_user,
            password=db_password,
            database=db_name,
        )
    except pymysql.Error as error:
        print("Error connecting to MySQL database:", error)
        return json.dumps({"error": str(error)})

    try:
        # The cursor's context manager closes it; the finally clause closes
        # the connection even if execute()/fetchall() raises.
        with connection.cursor() as cursor:
            # NOTE(review): model-generated SQL runs as-is; nothing here
            # restricts it to SELECT. No commit() is issued, but MySQL DDL
            # commits implicitly, so use a read-only database account.
            cursor.execute(query)
            rows = cursor.fetchall()
            # description is None for statements without a result set.
            column_names = [column[0] for column in cursor.description or ()]
            results = dictfetchall(cursor, rows)
    except pymysql.Error as error:
        print("Error executing query:", error)
        return json.dumps({"error": str(error)})
    finally:
        connection.close()

    # Show the result to the operator as a grid table.
    print(tabulate([list(row) for row in rows], headers=column_names, tablefmt="grid"))
    return json.dumps(results, default=str)
def query_details(query: str) -> str:
    """Handle the model's ``query_details`` function call.

    Echoes the generated SQL to stdout so the user can see what will run,
    then delegates execution to ``run_query``.

    Args:
        query: The SQL string supplied by the model.

    Returns:
        Whatever ``run_query(query)`` returns.
    """
    print(query)  # show the SQL before it is executed
    outcome = run_query(query)
    return outcome
def _print_reply(text):
    """Print a model reply in magenta, or a notice when the reply has no text."""
    if text:
        print(Fore.MAGENTA + text + Fore.RESET)
    else:
        print(Fore.YELLOW + "(the model returned no text)" + Fore.RESET)


def run_conversation():
    """Interactive loop that answers human-language questions about the DB.

    For each line typed at the ``" > "`` prompt:

    1. Send the question, prefixed with the schema text from
       ``./docs/query.sql``, to the model along with the ``query_details``
       function schema.
    2. If the model calls ``query_details``, run the SQL it generated.
    3. Send the query result back to the model and print its answer.

    The OpenAI key is read from the ``OPENAI_API_KEY`` environment variable.
    Empty input lines are skipped; end-of-input (Ctrl-D) or Ctrl-C ends the loop.
    """
    openai.api_key = os.environ.get("OPENAI_API_KEY", "open-ai-api-key-from-console")
    model = "gpt-3.5-turbo-16k-0613"  # use 'gpt-4' if available

    # The schema file contains non-ASCII characters (curly quotes), so do not
    # depend on the platform's default encoding; `with` closes the file.
    with open('./docs/query.sql', "r", encoding="utf-8") as schema_file:
        guide: str = schema_file.read()

    functions = [
        {
            "name": "query_details",
            "description": "Create an apt SQL query for a given human language query. Return the SQL query only",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "Generated SQL Query",
                    }
                },
                "required": ["query"],
            },
        }
    ]
    query_desc = (
        "Create an apt SQL query for the given human language query. For the tables created with CREATE statements like :\n\n"
        + guide
        + "\n\n Refer the tables' comment also. Return the SQL query only. \n\n\n "
    )
    print(query_desc)

    while True:
        try:
            human_query = input(" > ")
        except (EOFError, KeyboardInterrupt):
            # Leave cleanly instead of ending with a traceback.
            print()
            break
        if not human_query.strip():
            continue

        # Step 1, send model the user query and what functions it has access to
        payload = query_desc + "Human: \n\n" + human_query
        response = openai.ChatCompletion.create(
            model=model,
            messages=[{"role": "user", "content": payload}],
            functions=functions,
            function_call="auto",
        )
        message = response["choices"][0]["message"]

        # Step 2, check if the model wants to call a function
        function_call = message.get("function_call")
        if not function_call or function_call.get("name") != "query_details":
            # The model answered directly (or named an unknown function):
            # show its text rather than silently dropping the turn.
            _print_reply(message.get("content"))
            continue

        # Step 3, call the function. The model's arguments are not guaranteed
        # to be valid JSON containing a "query" key.
        try:
            arguments = json.loads(function_call["arguments"])
            sql = arguments["query"]
        except (json.JSONDecodeError, KeyError, TypeError) as error:
            print(Fore.RED + f"Could not parse the model's function call: {error}" + Fore.RESET)
            continue

        function_response = query_details(query=sql)
        if function_response is None:
            # Message content must be a string; report the failure to the model.
            function_response = json.dumps({"error": "query failed"})

        # Step 4, send the function result back so the model can answer.
        response = openai.ChatCompletion.create(
            model=model,
            messages=[
                {"role": "system",
                 "content": "You are a Database bot which replies to human language queries from user"},
                {"role": "user", "content": human_query},
                # Function-calling protocol: the assistant's function_call
                # message precedes the corresponding function result.
                message,
                {"role": "function", "name": "query_details", "content": function_response},
            ],
            functions=functions,
            function_call="auto",
        )
        _print_reply(response["choices"][0]["message"].get("content"))
if __name__ == "__main__":
    # Start the interactive loop only when run as a script, not on import.
    run_conversation()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- Episodes. `series` and `season` are nullable foreign keys into `series_info`
-- and `seasons`. Those tables appear later in this file, so create them first
-- or load this file with FOREIGN_KEY_CHECKS=0.
-- No separate UNIQUE index on `id`: the PRIMARY KEY already enforces it.
CREATE TABLE `episode_info` (
  `id` int NOT NULL AUTO_INCREMENT COMMENT 'Unique Id of the Episode ',
  `series` int DEFAULT NULL COMMENT 'Series Id of the episode',
  `season` int DEFAULT NULL COMMENT 'Season id of the episode ',
  `episode_index` int DEFAULT NULL COMMENT 'Episode number of the episode ',
  `title` mediumtext CHARACTER SET utf8mb3 COLLATE utf8mb3_unicode_ci COMMENT 'Title of the episode ',
  `description` longtext CHARACTER SET utf8mb3 COLLATE utf8mb3_unicode_ci COMMENT 'Description of the episode ',
  `air_date` date DEFAULT NULL COMMENT 'Air date of the episode ',
  `duration_mins` int DEFAULT NULL COMMENT 'Run time of the episode ',
  `poster_url` longtext CHARACTER SET utf8mb3 COLLATE utf8mb3_unicode_ci COMMENT 'Image URL for the banner of the episode ',
  `video_url` longtext CHARACTER SET utf8mb3 COLLATE utf8mb3_unicode_ci COMMENT 'Trailer Video URL of the episode ',
  PRIMARY KEY (`id`),
  KEY `series_info_idx` (`series`),
  KEY `seasons_info_idx` (`season`),
  CONSTRAINT `seasons_info` FOREIGN KEY (`season`) REFERENCES `seasons` (`id`),
  CONSTRAINT `series_info` FOREIGN KEY (`series`) REFERENCES `series_info` (`id`)
) ENGINE=InnoDB AUTO_INCREMENT=155 DEFAULT CHARSET=utf8mb3 COLLATE=utf8mb3_unicode_ci;
-- Maturity ratings; referenced by `series_info`.`maturity`.
CREATE TABLE `maturity` (
  `id` int NOT NULL AUTO_INCREMENT PRIMARY KEY,
  `rating_value` varchar(45) CHARACTER SET utf8mb3 COLLATE utf8mb3_unicode_ci DEFAULT NULL COMMENT 'Rating value of Maturity rating ',
  `description` varchar(45) CHARACTER SET utf8mb3 COLLATE utf8mb3_unicode_ci DEFAULT NULL COMMENT 'Description of Maturity rating ',
  `age_start` int DEFAULT NULL COMMENT 'Starting age of Maturity rating '
) ENGINE=InnoDB AUTO_INCREMENT=5 DEFAULT CHARSET=utf8mb3 COLLATE=utf8mb3_unicode_ci;
-- Seasons of a series. `series` is a nullable foreign key into `series_info`,
-- which appears later in this file (create it first or load this file with
-- FOREIGN_KEY_CHECKS=0).
-- NOTE: `release` is a reserved word in MySQL; quote it with backticks in queries.
-- No separate UNIQUE index on `id`: the PRIMARY KEY already enforces it.
CREATE TABLE `seasons` (
  `id` int NOT NULL AUTO_INCREMENT COMMENT 'Id of the seasons',
  `season_index` int DEFAULT NULL COMMENT 'Season number',
  `info` longtext CHARACTER SET utf8mb3 COLLATE utf8mb3_unicode_ci COMMENT 'Info about the season',
  `release` date DEFAULT NULL COMMENT 'Release date of a season',
  `episode_count` int DEFAULT NULL COMMENT 'Number of episodes in a season',
  `series` int DEFAULT NULL COMMENT 'Series id which the season belongs to, refers ‘series_info’ table',
  `title` mediumtext CHARACTER SET utf8mb3 COLLATE utf8mb3_unicode_ci COMMENT 'Title of the season',
  PRIMARY KEY (`id`),
  KEY `series_rel_idx` (`series`),
  CONSTRAINT `series_rel` FOREIGN KEY (`series`) REFERENCES `series_info` (`id`)
) ENGINE=InnoDB AUTO_INCREMENT=3 DEFAULT CHARSET=utf8mb3 COLLATE=utf8mb3_unicode_ci;
-- Series. `maturity` and `studios` are nullable foreign keys into the
-- `maturity` and `studios` tables; `seasons` is a plain count, not a key.
-- NOTE: `release` is a reserved word in MySQL; quote it with backticks in queries.
-- NOTE(review): `plot` uses utf8mb3_general_ci while the table default is
-- utf8mb3_unicode_ci; confirm whether that difference is intended.
-- No separate UNIQUE index on `id`: the PRIMARY KEY already enforces it.
CREATE TABLE `series_info` (
  `id` int NOT NULL AUTO_INCREMENT,
  `plot` longtext CHARACTER SET utf8mb3 COLLATE utf8mb3_general_ci COMMENT 'The plot line of the Series',
  `maturity` int DEFAULT NULL COMMENT 'The maturity rating of the Series, refers ‘maturity’ table',
  `release` date DEFAULT NULL COMMENT 'The release date of the Series',
  `seasons` int DEFAULT NULL COMMENT 'The number of seasons in the Series',
  `trailer_url` varchar(300) CHARACTER SET utf8mb3 COLLATE utf8mb3_unicode_ci DEFAULT NULL COMMENT 'The trailer url of the Series',
  `studios` int DEFAULT NULL COMMENT 'The studio that produced the Series, refers ‘studios’ table',
  `title` mediumtext CHARACTER SET utf8mb3 COLLATE utf8mb3_unicode_ci COMMENT 'The title of the Series',
  `poster_url` longtext CHARACTER SET utf8mb3 COLLATE utf8mb3_unicode_ci COMMENT 'The poster url of the Series',
  PRIMARY KEY (`id`),
  KEY `maturity_idx` (`maturity`),
  KEY `studios_idx` (`studios`),
  CONSTRAINT `maturity` FOREIGN KEY (`maturity`) REFERENCES `maturity` (`id`),
  CONSTRAINT `studios` FOREIGN KEY (`studios`) REFERENCES `studios` (`id`)
) ENGINE=InnoDB AUTO_INCREMENT=47 DEFAULT CHARSET=utf8mb3 COLLATE=utf8mb3_unicode_ci;
-- Production studios; referenced by `series_info`.`studios`.
CREATE TABLE `studios` (
  `id` int NOT NULL AUTO_INCREMENT PRIMARY KEY,
  `name` varchar(45) CHARACTER SET utf8mb3 COLLATE utf8mb3_unicode_ci DEFAULT NULL COMMENT 'Name of the studio',
  `description` longtext CHARACTER SET utf8mb3 COLLATE utf8mb3_unicode_ci COMMENT 'Description of the studio'
) ENGINE=InnoDB AUTO_INCREMENT=27 DEFAULT CHARSET=utf8mb3 COLLATE=utf8mb3_unicode_ci;
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment