-
-
Save cfahlgren1/7b17edbc82ac6e29d53481e52e7410e9 to your computer and use it in GitHub Desktop.
A comparison of SQLCoder-7B and NaturalSQL-6.7B on complex, multi-part natural-language questions over a PostgreSQL database schema.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "1a1bbfc0-8d6f-4b88-8feb-4f334d6cf220", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Requirement already satisfied: torch in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (2.1.1)\n", | |
"Requirement already satisfied: transformers in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (4.36.2)\n", | |
"Collecting bitsandbytes\n", | |
" Downloading bitsandbytes-0.42.0-py3-none-any.whl.metadata (9.9 kB)\n", | |
"Collecting accelerate\n", | |
" Downloading accelerate-0.26.1-py3-none-any.whl.metadata (18 kB)\n", | |
"Collecting sqlparse\n", | |
" Downloading sqlparse-0.4.4-py3-none-any.whl (41 kB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m41.2/41.2 kB\u001b[0m \u001b[31m1.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hRequirement already satisfied: filelock in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (from torch) (3.13.1)\n", | |
"Requirement already satisfied: typing-extensions in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (from torch) (4.9.0)\n", | |
"Requirement already satisfied: sympy in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (from torch) (1.12)\n", | |
"Requirement already satisfied: networkx in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (from torch) (3.2.1)\n", | |
"Requirement already satisfied: jinja2 in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (from torch) (3.1.2)\n", | |
"Requirement already satisfied: fsspec in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (from torch) (2023.12.2)\n", | |
"Requirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (from transformers) (0.20.1)\n", | |
"Requirement already satisfied: numpy>=1.17 in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (from transformers) (1.26.2)\n", | |
"Requirement already satisfied: packaging>=20.0 in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (from transformers) (23.2)\n", | |
"Requirement already satisfied: pyyaml>=5.1 in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (from transformers) (6.0.1)\n", | |
"Requirement already satisfied: regex!=2019.12.17 in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (from transformers) (2023.10.3)\n", | |
"Requirement already satisfied: requests in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (from transformers) (2.31.0)\n", | |
"Requirement already satisfied: tokenizers<0.19,>=0.14 in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (from transformers) (0.15.0)\n", | |
"Requirement already satisfied: safetensors>=0.3.1 in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (from transformers) (0.4.1)\n", | |
"Requirement already satisfied: tqdm>=4.27 in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (from transformers) (4.66.1)\n", | |
"Requirement already satisfied: scipy in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (from bitsandbytes) (1.11.4)\n", | |
"Requirement already satisfied: psutil in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (from accelerate) (5.9.7)\n", | |
"Requirement already satisfied: MarkupSafe>=2.0 in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (from jinja2->torch) (2.1.3)\n", | |
"Requirement already satisfied: charset-normalizer<4,>=2 in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (from requests->transformers) (3.3.2)\n", | |
"Requirement already satisfied: idna<4,>=2.5 in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (from requests->transformers) (3.6)\n", | |
"Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (from requests->transformers) (1.26.18)\n", | |
"Requirement already satisfied: certifi>=2017.4.17 in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (from requests->transformers) (2023.11.17)\n", | |
"Requirement already satisfied: mpmath>=0.19 in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (from sympy->torch) (1.3.0)\n", | |
"Downloading bitsandbytes-0.42.0-py3-none-any.whl (105.0 MB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m105.0/105.0 MB\u001b[0m \u001b[31m44.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", | |
"\u001b[?25hDownloading accelerate-0.26.1-py3-none-any.whl (270 kB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m270.9/270.9 kB\u001b[0m \u001b[31m57.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hInstalling collected packages: sqlparse, bitsandbytes, accelerate\n", | |
"Successfully installed accelerate-0.26.1 bitsandbytes-0.42.0 sqlparse-0.4.4\n", | |
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", | |
"\u001b[0m" | |
] | |
} | |
], | |
"source": [ | |
"# Pinned to the versions recorded in the install log below; %pip installs into the kernel's own environment.\n", | |
"%pip install torch==2.1.1 transformers==4.36.2 bitsandbytes==0.42.0 accelerate==0.26.1 sqlparse==0.4.4" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "6dfd04f8-3ff7-40d4-8c31-f89003cb2dee", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"from transformers import AutoTokenizer, AutoModelForCausalLM" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "17d9a6e8-305b-4a5f-9755-2866cbeb6161", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"True" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Fail fast if no GPU is visible: both models are loaded in float16 with device_map=\"auto\".\n", | |
"cuda_available = torch.cuda.is_available()\n", | |
"assert cuda_available, \"No CUDA device found; this notebook expects a GPU.\"\n", | |
"cuda_available" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "ef738ac2-c335-45f7-ae8a-4fe5d7575b6b", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "f1a65ff4bb594882b4d193a48747d751", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"tokenizer_config.json: 0%| | 0.00/915 [00:00<?, ?B/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "742736cea8594d9ebb68c65ffd1d48c8", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"tokenizer.model: 0%| | 0.00/493k [00:00<?, ?B/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "b1d6f59c7baa4633b37c0aef97e75e90", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"tokenizer.json: 0%| | 0.00/1.80M [00:00<?, ?B/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "7ff65e9cb4ca4e4ea8c8a89f8bb01a39", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"special_tokens_map.json: 0%| | 0.00/72.0 [00:00<?, ?B/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "a107ccb2bf7d488093f98592f16cec19", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"config.json: 0%| | 0.00/619 [00:00<?, ?B/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "6421070f390e43e99b281f9bf655b0fd", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"pytorch_model.bin.index.json: 0%| | 0.00/23.9k [00:00<?, ?B/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "1820c10d61a1409b975f04645e1d364e", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Downloading shards: 0%| | 0/2 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "a20a2f3f751d48eab244787097dc54d9", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"pytorch_model-00001-of-00002.bin: 0%| | 0.00/9.94G [00:00<?, ?B/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "c015d44b26ba41f09a6addbd4fee2ebe", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"pytorch_model-00002-of-00002.bin: 0%| | 0.00/4.54G [00:00<?, ?B/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "fa890f45089d409bb42f4a41d161df0e", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "142734bd6c4849ec9480c5caa1f33e4c", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"generation_config.json: 0%| | 0.00/116 [00:00<?, ?B/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"model_name = \"defog/sqlcoder-7b\"\n", | |
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n", | |
"# trust_remote_code is not needed: this checkpoint ships only config, tokenizer and weight\n", | |
"# files (no custom modeling code), and enabling it would let arbitrary Hub code run.\n", | |
"model = AutoModelForCausalLM.from_pretrained(\n", | |
"    model_name,\n", | |
"    torch_dtype=torch.float16,\n", | |
"    # To reduce GPU memory, uncomment one of these (requires bitsandbytes):\n", | |
"    # load_in_8bit=True,\n", | |
"    # load_in_4bit=True,\n", | |
"    device_map=\"auto\",\n", | |
"    use_cache=True,\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "750c8f94-5f4a-49f6-ba26-059b408d7d6a", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Question: Which employee had the highest salary in the 'Marketing' department, and what is the total budget of that department?\n", | |
"SQL: WITH highest_salary AS\n", | |
" (SELECT employees.employeeid,\n", | |
" employees.firstname,\n", | |
" employees.lastname,\n", | |
" MAX(employees.salary) AS max_salary\n", | |
" FROM employees\n", | |
" JOIN departments ON employees.departmentid = departments.departmentid\n", | |
" WHERE departments.departmentname = 'Marketing'\n", | |
" GROUP BY employees.employeeid,\n", | |
" employees.firstname,\n", | |
" employees.lastname),\n", | |
" total_budget AS\n", | |
" (SELECT departments.budget\n", | |
" FROM departments\n", | |
" JOIN highest_salary ON departments.departmentid = highest_salary.employeeid)\n", | |
"SELECT highest_salary.firstname,\n", | |
" highest_salary.lastname,\n", | |
" total_budget.budget\n", | |
"FROM highest_salary,\n", | |
" total_budget;\n", | |
"\n", | |
"```\n", | |
"Question: Identify the customer who placed the highest number of orders last year, and list the names and prices of the products they ordered most frequently.\n", | |
"SQL: WITH customer_orders AS\n", | |
" (SELECT o.customerid,\n", | |
" COUNT(*) AS num_orders\n", | |
" FROM orders o\n", | |
" WHERE o.orderdate BETWEEN (CURRENT_DATE - interval '1 year') AND CURRENT_DATE\n", | |
" GROUP BY o.customerid),\n", | |
" product_orders AS\n", | |
" (SELECT p.productid,\n", | |
" COUNT(*) AS num_orders\n", | |
" FROM products p\n", | |
" JOIN orders o ON p.productid = o.productid\n", | |
" WHERE o.orderdate BETWEEN (CURRENT_DATE - interval '1 year') AND CURRENT_DATE\n", | |
" GROUP BY p.productid),\n", | |
" customer_product_orders AS\n", | |
" (SELECT co.customerid,\n", | |
" pp.productid,\n", | |
" co.num_orders,\n", | |
" pp.productname,\n", | |
" pp.unitprice\n", | |
" FROM customer_orders co\n", | |
" JOIN product_orders pp ON co.customerid = pp.productid)\n", | |
"SELECT cpo.customerid,\n", | |
" cpo.productname,\n", | |
" cpo.unitprice\n", | |
"FROM customer_product_orders cpo\n", | |
"WHERE cpo.num_orders =\n", | |
" (SELECT MAX(num_orders)\n", | |
" FROM customer_product_orders)\n", | |
"ORDER BY cpo.num_orders DESC NULLS LAST;\n", | |
"\n", | |
"```\n", | |
"Question: Find the supplier that provided the most products currently out of stock, and show the names of these out-of-stock products.\n", | |
"SQL: \n", | |
"SELECT p.productname,\n", | |
" s.suppliername,\n", | |
" COUNT(*) AS out_of_stock_products_count\n", | |
"FROM products p\n", | |
"JOIN suppliers s ON p.supplierid = s.supplierid\n", | |
"WHERE p.unitsinstock = 0\n", | |
"GROUP BY p.productname,\n", | |
" s.suppliername\n", | |
"ORDER BY out_of_stock_products_count DESC\n", | |
"LIMIT 1;\n", | |
"\n", | |
"```\n", | |
"Question: Determine the customer with the largest total order value last month, and list all the products and their quantities they ordered.\n", | |
"SQL: \n", | |
"SELECT c.customerid,\n", | |
" c.companyname,\n", | |
" o.orderid,\n", | |
" p.productname,\n", | |
" o.quantityordered,\n", | |
" to_char(o.orderdate, 'DD-MM-YYYY') AS order_date\n", | |
"FROM customers c\n", | |
"JOIN orders o ON c.customerid = o.customerid\n", | |
"JOIN orderdetails od ON o.orderid = od.orderid\n", | |
"JOIN products p ON od.productid = p.productid\n", | |
"WHERE to_date(to_char(o.orderdate, 'MM-DD-YYYY'), 'MM-DD-YYYY') BETWEEN to_date('01-01-2020', 'MM-DD-YYYY') AND CURRENT_DATE\n", | |
"ORDER BY total_order_value DESC\n", | |
"LIMIT 1;\n", | |
"\n", | |
"```\n", | |
"Question: Who is the employee that handled the most orders last quarter, and what is the status of their most recent order?\n", | |
"SQL: WITH employee_orders AS\n", | |
" (SELECT o.employeeid,\n", | |
" COUNT(o.orderid) AS order_count\n", | |
" FROM orders o\n", | |
" JOIN employees e ON o.employeeid = e.employeeid\n", | |
" WHERE o.orderdate BETWEEN (CURRENT_DATE - interval '1 quarter') AND CURRENT_DATE\n", | |
" GROUP BY o.employeeid),\n", | |
" employee_recent_order AS\n", | |
" (SELECT e.employeeid,\n", | |
" e.firstname,\n", | |
" e.lastname,\n", | |
" o.orderstatus\n", | |
" FROM employees e\n", | |
" JOIN orders o ON e.employeeid = o.employeeid\n", | |
" WHERE o.orderdate =\n", | |
" (SELECT MAX(orderdate)\n", | |
" FROM orders\n", | |
" WHERE employeeid = e.employeeid))\n", | |
"SELECT eo.employeeid,\n", | |
" eo.firstname,\n", | |
" eo.lastname,\n", | |
" er.orderstatus\n", | |
"FROM employee_orders eo\n", | |
"JOIN employee_recent_order er ON eo.employeeid = er.employeeid\n", | |
"ORDER BY eo.order_count DESC NULLS LAST\n", | |
"LIMIT 1;\n", | |
"\n", | |
"```\n" | |
] | |
} | |
], | |
"source": [ | |
"import sqlparse\n", | |
"\n", | |
"# NOTE(review): the schema in the prompt defines no suppliers table and no order-line table\n", | |
"# (orders has no product or quantity columns), so questions 2-4 can only be partially answered;\n", | |
"# both models end up inventing tables such as `suppliers` and `orderdetails`/`order_details`.\n", | |
"questions = [\n", | |
"    \"Which employee had the highest salary in the 'Marketing' department, and what is the total budget of that department?\",\n", | |
"    \"Identify the customer who placed the highest number of orders last year, and list the names and prices of the products they ordered most frequently.\",\n", | |
"    \"Find the supplier that provided the most products currently out of stock, and show the names of these out-of-stock products.\",\n", | |
"    \"Determine the customer with the largest total order value last month, and list all the products and their quantities they ordered.\",\n", | |
"    \"Who is the employee that handled the most orders last quarter, and what is the status of their most recent order?\"\n", | |
"]\n", | |
"\n", | |
"eos_token_id = tokenizer.eos_token_id\n", | |
"\n", | |
"for question in questions:\n", | |
"    prompt = \"\"\"### Task\n", | |
" Generate a SQL query to answer the following question:\n", | |
" `{question}`\n", | |
"\n", | |
" ### PostgreSQL Database Schema\n", | |
" This query will run on a database whose schema is represented in this string:\n", | |
"\n", | |
" CREATE TABLE employees (\n", | |
" employeeid SERIAL PRIMARY KEY,\n", | |
" firstname VARCHAR(50),\n", | |
" lastname VARCHAR(50),\n", | |
" birthdate DATE,\n", | |
" email VARCHAR(100),\n", | |
" phone VARCHAR(20),\n", | |
" salary NUMERIC(10, 2),\n", | |
" hiredate TIMESTAMP,\n", | |
" departmentid INT,\n", | |
" address VARCHAR(200),\n", | |
" city VARCHAR(50),\n", | |
" state VARCHAR(50),\n", | |
" zipcode VARCHAR(10),\n", | |
" country VARCHAR(50),\n", | |
" isactive BOOLEAN\n", | |
" );\n", | |
"\n", | |
"CREATE TABLE departments (\n", | |
" departmentid SERIAL PRIMARY KEY,\n", | |
" departmentname VARCHAR(100),\n", | |
" managerid INT,\n", | |
" location VARCHAR(100),\n", | |
" phonenumber VARCHAR(20),\n", | |
" budget NUMERIC(10, 2),\n", | |
" startdate DATE,\n", | |
" enddate DATE,\n", | |
" description TEXT,\n", | |
" email VARCHAR(100),\n", | |
" website VARCHAR(100)\n", | |
");\n", | |
"\n", | |
"CREATE TABLE products (\n", | |
" productid SERIAL PRIMARY KEY,\n", | |
" productname VARCHAR(100),\n", | |
" supplierid INT,\n", | |
" categoryid INT,\n", | |
" quantityperunit VARCHAR(50),\n", | |
" unitprice NUMERIC(10, 2),\n", | |
" unitsinstock SMALLINT,\n", | |
" unitsonorder SMALLINT,\n", | |
" reorderlevel SMALLINT,\n", | |
" discontinued BOOLEAN,\n", | |
" registered TIMESTAMP,\n", | |
" description TEXT\n", | |
");\n", | |
"\n", | |
"CREATE TABLE orders (\n", | |
" orderid SERIAL PRIMARY KEY,\n", | |
" customerid INT,\n", | |
" employeeid INT,\n", | |
" orderdate TIMESTAMP,\n", | |
" requireddate TIMESTAMP,\n", | |
" shippeddate TIMESTAMP,\n", | |
" shipvia INT,\n", | |
" freight NUMERIC(10, 2),\n", | |
" shipname VARCHAR(100),\n", | |
" shipaddress VARCHAR(200),\n", | |
" shipcity VARCHAR(50),\n", | |
" shipregion VARCHAR(50),\n", | |
" shippostalcode VARCHAR(10),\n", | |
" shipcountry VARCHAR(50),\n", | |
" status VARCHAR(20)\n", | |
");\n", | |
"\n", | |
"CREATE TABLE customers (\n", | |
" customerid SERIAL PRIMARY KEY,\n", | |
" companyname VARCHAR(100),\n", | |
" contactname VARCHAR(100),\n", | |
" contacttitle VARCHAR(50),\n", | |
" address VARCHAR(200),\n", | |
" city VARCHAR(50),\n", | |
" region VARCHAR(50),\n", | |
" postalcode VARCHAR(10),\n", | |
" country VARCHAR(50),\n", | |
" phone VARCHAR(20),\n", | |
" fax VARCHAR(20),\n", | |
" email VARCHAR(100),\n", | |
" website VARCHAR(100),\n", | |
" preferredcontactmethod VARCHAR(50),\n", | |
" registrationdate DATE\n", | |
");\n", | |
"\n", | |
" ### SQL\n", | |
" Given the database schema, here is the SQL query that answers `{question}`:\n", | |
" ```sql\n", | |
" \"\"\".format(question=question)\n", | |
"\n", | |
"    # Greedy decoding (do_sample=False, num_beams=1) keeps the comparison deterministic.\n", | |
"    inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n", | |
"    generated_ids = model.generate(\n", | |
"        **inputs,\n", | |
"        num_return_sequences=1,\n", | |
"        eos_token_id=eos_token_id,\n", | |
"        pad_token_id=eos_token_id,\n", | |
"        max_new_tokens=400,\n", | |
"        do_sample=False,\n", | |
"        num_beams=1,\n", | |
"    )\n", | |
"\n", | |
"    outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)\n", | |
"\n", | |
"    # The decoded text echoes the prompt, which ends with an opening ```sql fence:\n", | |
"    # keep only what follows it, and cut at the closing ``` fence so that the fence and\n", | |
"    # anything the model generated after it are not printed as part of the SQL.\n", | |
"    sql = outputs[0].split(\"```sql\")[-1].split(\"```\")[0]\n", | |
"    print(\"Question: \" + question)\n", | |
"    print(\"SQL: \" + sqlparse.format(sql, reindent=True))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "00bd2603-1178-4dcf-bab3-67823fbec034", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" | |
] | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "726f29e7e44a44dc803df70cc5261a6a", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Loading checkpoint shards: 0%| | 0/3 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n", | |
"Setting `pad_token_id` to `eos_token_id`:32023 for open-end generation.\n", | |
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n", | |
"Setting `pad_token_id` to `eos_token_id`:32023 for open-end generation.\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Question: Which employee had the highest salary in the 'Marketing' department, and what is the total budget of that department?\n", | |
"SQL: SELECT e.firstname,\n", | |
" e.lastname,\n", | |
" e.salary,\n", | |
" d.budget\n", | |
"FROM employees e\n", | |
"JOIN departments d ON e.departmentid = d.departmentid\n", | |
"WHERE d.departmentname ILIKE '%Marketing%'\n", | |
"ORDER BY e.salary DESC\n", | |
"LIMIT 1;\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n", | |
"Setting `pad_token_id` to `eos_token_id`:32023 for open-end generation.\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Question: Identify the customer who placed the highest number of orders last year, and list the names and prices of the products they ordered most frequently.\n", | |
"SQL: WITH LastYearOrders AS\n", | |
" (SELECT customerid,\n", | |
" COUNT(*) AS order_count\n", | |
" FROM orders\n", | |
" WHERE orderdate >= (CURRENT_DATE - INTERVAL '1 year')\n", | |
" GROUP BY customerid\n", | |
" ORDER BY order_count DESC\n", | |
" LIMIT 1),\n", | |
" FrequentlyOrderedProducts AS\n", | |
" (SELECT op.orderid,\n", | |
" op.productid,\n", | |
" p.productname,\n", | |
" p.unitprice\n", | |
" FROM order_details op\n", | |
" JOIN products p ON op.productid = p.productid\n", | |
" JOIN LastYearOrders ly ON op.orderid = ly.orderid\n", | |
" ORDER BY op.orderid,\n", | |
" op.quantity DESC)\n", | |
"SELECT c.companyname,\n", | |
" fop.productname,\n", | |
" fop.unitprice\n", | |
"FROM FrequentlyOrderedProducts fop\n", | |
"JOIN customers c ON fop.customerid = c.customerid\n", | |
"JOIN orders o ON fop.orderid = o.orderid\n", | |
"WHERE o.orderdate >= (CURRENT_DATE - INTERVAL '1 year')\n", | |
"ORDER BY fop.orderid,\n", | |
" fop.quantity DESC;\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n", | |
"Setting `pad_token_id` to `eos_token_id`:32023 for open-end generation.\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Question: Find the supplier that provided the most products currently out of stock, and show the names of these out-of-stock products.\n", | |
"SQL: SELECT s.supplierid,\n", | |
" s.companyname,\n", | |
" p.productname\n", | |
"FROM suppliers s\n", | |
"JOIN products p ON s.supplierid = p.supplierid\n", | |
"WHERE p.unitsinstock = 0\n", | |
"GROUP BY s.supplierid,\n", | |
" p.productname\n", | |
"ORDER BY COUNT(p.productid) DESC\n", | |
"LIMIT 1;\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n", | |
"Setting `pad_token_id` to `eos_token_id`:32023 for open-end generation.\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Question: Determine the customer with the largest total order value last month, and list all the products and their quantities they ordered.\n", | |
"SQL: WITH LastMonthOrders AS\n", | |
" (SELECT o.customerid,\n", | |
" o.orderid,\n", | |
" op.productid,\n", | |
" op.quantity,\n", | |
" p.productname,\n", | |
" op.quantity * p.unitprice AS total_order_value\n", | |
" FROM orders o\n", | |
" JOIN order_details op ON o.orderid = op.orderid\n", | |
" JOIN products p ON op.productid = p.productid\n", | |
" WHERE o.orderdate >= date_trunc('month', CURRENT_DATE) - INTERVAL '1 month'\n", | |
" AND o.orderdate < date_trunc('month', CURRENT_DATE) ),\n", | |
" CustomerTotalOrderValue AS\n", | |
" (SELECT customerid,\n", | |
" SUM(total_order_value) AS total_value\n", | |
" FROM LastMonthOrders\n", | |
" GROUP BY customerid\n", | |
" ORDER BY total_value DESC\n", | |
" LIMIT 1)\n", | |
"SELECT c.customerid,\n", | |
" c.companyname,\n", | |
" lm.productid,\n", | |
" lm.productname,\n", | |
" lm.quantity\n", | |
"FROM CustomerTotalOrderValue ctov\n", | |
"JOIN LastMonthOrders lm ON ctov.customerid = lm.customerid\n", | |
"JOIN customers c ON c.customerid = lm.customerid;\n", | |
"Question: Who is the employee that handled the most orders last quarter, and what is the status of their most recent order?\n", | |
"SQL: WITH LastQuarterOrders AS\n", | |
" (SELECT employeeid,\n", | |
" COUNT(*) AS order_count,\n", | |
" MAX(orderdate) AS most_recent_order_date\n", | |
" FROM orders\n", | |
" WHERE orderdate >= date_trunc('quarter', CURRENT_DATE) - INTERVAL '3 months'\n", | |
" AND orderdate < date_trunc('quarter', CURRENT_DATE)\n", | |
" GROUP BY employeeid\n", | |
" ORDER BY order_count DESC\n", | |
" LIMIT 1),\n", | |
" MostRecentOrderStatus AS\n", | |
" (SELECT o.employeeid,\n", | |
" o.status\n", | |
" FROM orders o\n", | |
" JOIN LastQuarterOrders lqo ON o.employeeid = lqo.employeeid\n", | |
" WHERE o.orderdate = lqo.most_recent_order_date )\n", | |
"SELECT e.firstname,\n", | |
" e.lastname,\n", | |
" mros.status AS most_recent_order_status\n", | |
"FROM employees e\n", | |
"JOIN LastQuarterOrders lqo ON e.employeeid = lqo.employeeid\n", | |
"JOIN MostRecentOrderStatus mros ON lqo.employeeid = mros.employeeid;\n" | |
] | |
} | |
], | |
"source": [ | |
"import gc\n", | |
"\n", | |
"import sqlparse\n", | |
"import torch\n", | |
"from transformers import AutoModelForCausalLM, AutoTokenizer\n", | |
"\n", | |
"# Release the previously loaded model (if any) so two 7B float16 models\n", | |
"# do not have to share GPU memory.\n", | |
"if \"model\" in globals():\n", | |
"    del model\n", | |
"    gc.collect()\n", | |
"    torch.cuda.empty_cache()\n", | |
"\n", | |
"tokenizer = AutoTokenizer.from_pretrained(\"cfahlgren1/NaturalSQL-6.7B-v0\")\n", | |
"model = AutoModelForCausalLM.from_pretrained(\n", | |
"    \"cfahlgren1/NaturalSQL-6.7B-v0\",\n", | |
"    device_map=\"auto\",\n", | |
"    torch_dtype=torch.float16,\n", | |
")\n", | |
"\n", | |
"# 32023 is the id of <|EOT|> token\n", | |
"EOT_TOKEN_ID = 32023\n", | |
"\n", | |
"for question in questions:\n", | |
"    prompt = \"\"\"### Task\n", | |
" Generate a SQL query to answer the following question:\n", | |
" `{question}`\n", | |
" \n", | |
" ### PostgreSQL Database Schema\n", | |
" This query will run on a database whose schema is represented in this string:\n", | |
"\n", | |
" CREATE TABLE employees (\n", | |
" employeeid SERIAL PRIMARY KEY,\n", | |
" firstname VARCHAR(50),\n", | |
" lastname VARCHAR(50),\n", | |
" birthdate DATE,\n", | |
" email VARCHAR(100),\n", | |
" phone VARCHAR(20),\n", | |
" salary NUMERIC(10, 2),\n", | |
" hiredate TIMESTAMP,\n", | |
" departmentid INT,\n", | |
" address VARCHAR(200),\n", | |
" city VARCHAR(50),\n", | |
" state VARCHAR(50),\n", | |
" zipcode VARCHAR(10),\n", | |
" country VARCHAR(50),\n", | |
" isactive BOOLEAN\n", | |
");\n", | |
"\n", | |
"CREATE TABLE departments (\n", | |
" departmentid SERIAL PRIMARY KEY,\n", | |
" departmentname VARCHAR(100),\n", | |
" managerid INT,\n", | |
" location VARCHAR(100),\n", | |
" phonenumber VARCHAR(20),\n", | |
" budget NUMERIC(10, 2),\n", | |
" startdate DATE,\n", | |
" enddate DATE,\n", | |
" description TEXT,\n", | |
" email VARCHAR(100),\n", | |
" website VARCHAR(100)\n", | |
");\n", | |
"\n", | |
"CREATE TABLE products (\n", | |
" productid SERIAL PRIMARY KEY,\n", | |
" productname VARCHAR(100),\n", | |
" supplierid INT,\n", | |
" categoryid INT,\n", | |
" quantityperunit VARCHAR(50),\n", | |
" unitprice NUMERIC(10, 2),\n", | |
" unitsinstock SMALLINT,\n", | |
" unitsonorder SMALLINT,\n", | |
" reorderlevel SMALLINT,\n", | |
" discontinued BOOLEAN,\n", | |
" registered TIMESTAMP,\n", | |
" description TEXT\n", | |
");\n", | |
"\n", | |
"CREATE TABLE orders (\n", | |
" orderid SERIAL PRIMARY KEY,\n", | |
" customerid INT,\n", | |
" employeeid INT,\n", | |
" orderdate TIMESTAMP,\n", | |
" requireddate TIMESTAMP,\n", | |
" shippeddate TIMESTAMP,\n", | |
" shipvia INT,\n", | |
" freight NUMERIC(10, 2),\n", | |
" shipname VARCHAR(100),\n", | |
" shipaddress VARCHAR(200),\n", | |
" shipcity VARCHAR(50),\n", | |
" shipregion VARCHAR(50),\n", | |
" shippostalcode VARCHAR(10),\n", | |
" shipcountry VARCHAR(50),\n", | |
" status VARCHAR(20)\n", | |
");\n", | |
"\n", | |
"CREATE TABLE customers (\n", | |
" customerid SERIAL PRIMARY KEY,\n", | |
" companyname VARCHAR(100),\n", | |
" contactname VARCHAR(100),\n", | |
" contacttitle VARCHAR(50),\n", | |
" address VARCHAR(200),\n", | |
" city VARCHAR(50),\n", | |
" region VARCHAR(50),\n", | |
" postalcode VARCHAR(10),\n", | |
" country VARCHAR(50),\n", | |
" phone VARCHAR(20),\n", | |
" fax VARCHAR(20),\n", | |
" email VARCHAR(100),\n", | |
" website VARCHAR(100),\n", | |
" preferredcontactmethod VARCHAR(50),\n", | |
" registrationdate DATE\n", | |
");\n", | |
" \n", | |
" ### SQL\n", | |
" Given the database schema, here is the SQL query that answers `{question}`:\n", | |
" ```sql\n", | |
" \"\"\".format(question=question)\n", | |
"    messages = [\n", | |
"        {'role': 'user', 'content': prompt}\n", | |
"    ]\n", | |
"\n", | |
"    inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors=\"pt\").to(model.device)\n", | |
"\n", | |
"    # One unpadded prompt: an all-ones mask and pad = EOS are exactly what generate()\n", | |
"    # would otherwise infer (with a warning). top_k/top_p are not passed because greedy\n", | |
"    # decoding (do_sample=False) ignores them.\n", | |
"    outputs = model.generate(\n", | |
"        inputs,\n", | |
"        attention_mask=torch.ones_like(inputs),\n", | |
"        max_new_tokens=512,\n", | |
"        do_sample=False,\n", | |
"        num_return_sequences=1,\n", | |
"        eos_token_id=EOT_TOKEN_ID,\n", | |
"        pad_token_id=EOT_TOKEN_ID,\n", | |
"    )\n", | |
"\n", | |
"    # Decode only the newly generated tokens (everything after the prompt).\n", | |
"    sql = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)\n", | |
"    print(\"Question: \" + question)\n", | |
"    print(\"SQL: \" + sqlparse.format(sql, reindent=True))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "bba0e098-0e64-4979-a966-90ff337a3f7a", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "ComfyUI", | |
"language": "python", | |
"name": "comfyui" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.10.13" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment