Skip to content

Instantly share code, notes, and snippets.

@cfahlgren1
Last active January 16, 2024 04:02
Show Gist options
  • Save cfahlgren1/7b17edbc82ac6e29d53481e52e7410e9 to your computer and use it in GitHub Desktop.
Save cfahlgren1/7b17edbc82ac6e29d53481e52e7410e9 to your computer and use it in GitHub Desktop.
A test for SQLCoder-7B and NaturalSQL-6.7B on complex, multipart natural language questions over a database schema
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "1a1bbfc0-8d6f-4b88-8feb-4f334d6cf220",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: torch in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (2.1.1)\n",
"Requirement already satisfied: transformers in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (4.36.2)\n",
"Collecting bitsandbytes\n",
" Downloading bitsandbytes-0.42.0-py3-none-any.whl.metadata (9.9 kB)\n",
"Collecting accelerate\n",
" Downloading accelerate-0.26.1-py3-none-any.whl.metadata (18 kB)\n",
"Collecting sqlparse\n",
" Downloading sqlparse-0.4.4-py3-none-any.whl (41 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m41.2/41.2 kB\u001b[0m \u001b[31m1.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: filelock in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (from torch) (3.13.1)\n",
"Requirement already satisfied: typing-extensions in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (from torch) (4.9.0)\n",
"Requirement already satisfied: sympy in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (from torch) (1.12)\n",
"Requirement already satisfied: networkx in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (from torch) (3.2.1)\n",
"Requirement already satisfied: jinja2 in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (from torch) (3.1.2)\n",
"Requirement already satisfied: fsspec in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (from torch) (2023.12.2)\n",
"Requirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (from transformers) (0.20.1)\n",
"Requirement already satisfied: numpy>=1.17 in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (from transformers) (1.26.2)\n",
"Requirement already satisfied: packaging>=20.0 in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (from transformers) (23.2)\n",
"Requirement already satisfied: pyyaml>=5.1 in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (from transformers) (6.0.1)\n",
"Requirement already satisfied: regex!=2019.12.17 in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (from transformers) (2023.10.3)\n",
"Requirement already satisfied: requests in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (from transformers) (2.31.0)\n",
"Requirement already satisfied: tokenizers<0.19,>=0.14 in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (from transformers) (0.15.0)\n",
"Requirement already satisfied: safetensors>=0.3.1 in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (from transformers) (0.4.1)\n",
"Requirement already satisfied: tqdm>=4.27 in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (from transformers) (4.66.1)\n",
"Requirement already satisfied: scipy in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (from bitsandbytes) (1.11.4)\n",
"Requirement already satisfied: psutil in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (from accelerate) (5.9.7)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (from jinja2->torch) (2.1.3)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (from requests->transformers) (3.3.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (from requests->transformers) (3.6)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (from requests->transformers) (1.26.18)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (from requests->transformers) (2023.11.17)\n",
"Requirement already satisfied: mpmath>=0.19 in /opt/micromamba/envs/comfyui/lib/python3.10/site-packages (from sympy->torch) (1.3.0)\n",
"Downloading bitsandbytes-0.42.0-py3-none-any.whl (105.0 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m105.0/105.0 MB\u001b[0m \u001b[31m44.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
"\u001b[?25hDownloading accelerate-0.26.1-py3-none-any.whl (270 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m270.9/270.9 kB\u001b[0m \u001b[31m57.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hInstalling collected packages: sqlparse, bitsandbytes, accelerate\n",
"Successfully installed accelerate-0.26.1 bitsandbytes-0.42.0 sqlparse-0.4.4\n",
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
"\u001b[0m"
]
}
],
"source": [
"!pip install torch transformers bitsandbytes accelerate sqlparse"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "6dfd04f8-3ff7-40d4-8c31-f89003cb2dee",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from transformers import AutoTokenizer, AutoModelForCausalLM"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "17d9a6e8-305b-4a5f-9755-2866cbeb6161",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.cuda.is_available()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "ef738ac2-c335-45f7-ae8a-4fe5d7575b6b",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f1a65ff4bb594882b4d193a48747d751",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"tokenizer_config.json: 0%| | 0.00/915 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "742736cea8594d9ebb68c65ffd1d48c8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"tokenizer.model: 0%| | 0.00/493k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b1d6f59c7baa4633b37c0aef97e75e90",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"tokenizer.json: 0%| | 0.00/1.80M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7ff65e9cb4ca4e4ea8c8a89f8bb01a39",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"special_tokens_map.json: 0%| | 0.00/72.0 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a107ccb2bf7d488093f98592f16cec19",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"config.json: 0%| | 0.00/619 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6421070f390e43e99b281f9bf655b0fd",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"pytorch_model.bin.index.json: 0%| | 0.00/23.9k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1820c10d61a1409b975f04645e1d364e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading shards: 0%| | 0/2 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a20a2f3f751d48eab244787097dc54d9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"pytorch_model-00001-of-00002.bin: 0%| | 0.00/9.94G [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c015d44b26ba41f09a6addbd4fee2ebe",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"pytorch_model-00002-of-00002.bin: 0%| | 0.00/4.54G [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "fa890f45089d409bb42f4a41d161df0e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "142734bd6c4849ec9480c5caa1f33e4c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"generation_config.json: 0%| | 0.00/116 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model_name = \"defog/sqlcoder-7b\"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
" model_name,\n",
" trust_remote_code=True,\n",
" torch_dtype=torch.float16,\n",
" # load_in_8bit=True,\n",
" # load_in_4bit=True,\n",
" device_map=\"auto\",\n",
" use_cache=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "750c8f94-5f4a-49f6-ba26-059b408d7d6a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Question: Which employee had the highest salary in the 'Marketing' department, and what is the total budget of that department?\n",
"SQL: WITH highest_salary AS\n",
" (SELECT employees.employeeid,\n",
" employees.firstname,\n",
" employees.lastname,\n",
" MAX(employees.salary) AS max_salary\n",
" FROM employees\n",
" JOIN departments ON employees.departmentid = departments.departmentid\n",
" WHERE departments.departmentname = 'Marketing'\n",
" GROUP BY employees.employeeid,\n",
" employees.firstname,\n",
" employees.lastname),\n",
" total_budget AS\n",
" (SELECT departments.budget\n",
" FROM departments\n",
" JOIN highest_salary ON departments.departmentid = highest_salary.employeeid)\n",
"SELECT highest_salary.firstname,\n",
" highest_salary.lastname,\n",
" total_budget.budget\n",
"FROM highest_salary,\n",
" total_budget;\n",
"\n",
"```\n",
"Question: Identify the customer who placed the highest number of orders last year, and list the names and prices of the products they ordered most frequently.\n",
"SQL: WITH customer_orders AS\n",
" (SELECT o.customerid,\n",
" COUNT(*) AS num_orders\n",
" FROM orders o\n",
" WHERE o.orderdate BETWEEN (CURRENT_DATE - interval '1 year') AND CURRENT_DATE\n",
" GROUP BY o.customerid),\n",
" product_orders AS\n",
" (SELECT p.productid,\n",
" COUNT(*) AS num_orders\n",
" FROM products p\n",
" JOIN orders o ON p.productid = o.productid\n",
" WHERE o.orderdate BETWEEN (CURRENT_DATE - interval '1 year') AND CURRENT_DATE\n",
" GROUP BY p.productid),\n",
" customer_product_orders AS\n",
" (SELECT co.customerid,\n",
" pp.productid,\n",
" co.num_orders,\n",
" pp.productname,\n",
" pp.unitprice\n",
" FROM customer_orders co\n",
" JOIN product_orders pp ON co.customerid = pp.productid)\n",
"SELECT cpo.customerid,\n",
" cpo.productname,\n",
" cpo.unitprice\n",
"FROM customer_product_orders cpo\n",
"WHERE cpo.num_orders =\n",
" (SELECT MAX(num_orders)\n",
" FROM customer_product_orders)\n",
"ORDER BY cpo.num_orders DESC NULLS LAST;\n",
"\n",
"```\n",
"Question: Find the supplier that provided the most products currently out of stock, and show the names of these out-of-stock products.\n",
"SQL: \n",
"SELECT p.productname,\n",
" s.suppliername,\n",
" COUNT(*) AS out_of_stock_products_count\n",
"FROM products p\n",
"JOIN suppliers s ON p.supplierid = s.supplierid\n",
"WHERE p.unitsinstock = 0\n",
"GROUP BY p.productname,\n",
" s.suppliername\n",
"ORDER BY out_of_stock_products_count DESC\n",
"LIMIT 1;\n",
"\n",
"```\n",
"Question: Determine the customer with the largest total order value last month, and list all the products and their quantities they ordered.\n",
"SQL: \n",
"SELECT c.customerid,\n",
" c.companyname,\n",
" o.orderid,\n",
" p.productname,\n",
" o.quantityordered,\n",
" to_char(o.orderdate, 'DD-MM-YYYY') AS order_date\n",
"FROM customers c\n",
"JOIN orders o ON c.customerid = o.customerid\n",
"JOIN orderdetails od ON o.orderid = od.orderid\n",
"JOIN products p ON od.productid = p.productid\n",
"WHERE to_date(to_char(o.orderdate, 'MM-DD-YYYY'), 'MM-DD-YYYY') BETWEEN to_date('01-01-2020', 'MM-DD-YYYY') AND CURRENT_DATE\n",
"ORDER BY total_order_value DESC\n",
"LIMIT 1;\n",
"\n",
"```\n",
"Question: Who is the employee that handled the most orders last quarter, and what is the status of their most recent order?\n",
"SQL: WITH employee_orders AS\n",
" (SELECT o.employeeid,\n",
" COUNT(o.orderid) AS order_count\n",
" FROM orders o\n",
" JOIN employees e ON o.employeeid = e.employeeid\n",
" WHERE o.orderdate BETWEEN (CURRENT_DATE - interval '1 quarter') AND CURRENT_DATE\n",
" GROUP BY o.employeeid),\n",
" employee_recent_order AS\n",
" (SELECT e.employeeid,\n",
" e.firstname,\n",
" e.lastname,\n",
" o.orderstatus\n",
" FROM employees e\n",
" JOIN orders o ON e.employeeid = o.employeeid\n",
" WHERE o.orderdate =\n",
" (SELECT MAX(orderdate)\n",
" FROM orders\n",
" WHERE employeeid = e.employeeid))\n",
"SELECT eo.employeeid,\n",
" eo.firstname,\n",
" eo.lastname,\n",
" er.orderstatus\n",
"FROM employee_orders eo\n",
"JOIN employee_recent_order er ON eo.employeeid = er.employeeid\n",
"ORDER BY eo.order_count DESC NULLS LAST\n",
"LIMIT 1;\n",
"\n",
"```\n"
]
}
],
"source": [
"questions = [\n",
" \"Which employee had the highest salary in the 'Marketing' department, and what is the total budget of that department?\",\n",
" \"Identify the customer who placed the highest number of orders last year, and list the names and prices of the products they ordered most frequently.\",\n",
" \"Find the supplier that provided the most products currently out of stock, and show the names of these out-of-stock products.\",\n",
" \"Determine the customer with the largest total order value last month, and list all the products and their quantities they ordered.\",\n",
" \"Who is the employee that handled the most orders last quarter, and what is the status of their most recent order?\"\n",
"]\n",
"\n",
"for question in questions:\n",
" \n",
" prompt = \"\"\"### Task\n",
" Generate a SQL query to answer the following question:\n",
" `{question}`\n",
"\n",
" ### PostgreSQL Database Schema\n",
" This query will run on a database whose schema is represented in this string:\n",
"\n",
" CREATE TABLE employees (\n",
" employeeid SERIAL PRIMARY KEY,\n",
" firstname VARCHAR(50),\n",
" lastname VARCHAR(50),\n",
" birthdate DATE,\n",
" email VARCHAR(100),\n",
" phone VARCHAR(20),\n",
" salary NUMERIC(10, 2),\n",
" hiredate TIMESTAMP,\n",
" departmentid INT,\n",
" address VARCHAR(200),\n",
" city VARCHAR(50),\n",
" state VARCHAR(50),\n",
" zipcode VARCHAR(10),\n",
" country VARCHAR(50),\n",
" isactive BOOLEAN\n",
" );\n",
"\n",
"CREATE TABLE departments (\n",
" departmentid SERIAL PRIMARY KEY,\n",
" departmentname VARCHAR(100),\n",
" managerid INT,\n",
" location VARCHAR(100),\n",
" phonenumber VARCHAR(20),\n",
" budget NUMERIC(10, 2),\n",
" startdate DATE,\n",
" enddate DATE,\n",
" description TEXT,\n",
" email VARCHAR(100),\n",
" website VARCHAR(100)\n",
");\n",
"\n",
"CREATE TABLE products (\n",
" productid SERIAL PRIMARY KEY,\n",
" productname VARCHAR(100),\n",
" supplierid INT,\n",
" categoryid INT,\n",
" quantityperunit VARCHAR(50),\n",
" unitprice NUMERIC(10, 2),\n",
" unitsinstock SMALLINT,\n",
" unitsonorder SMALLINT,\n",
" reorderlevel SMALLINT,\n",
" discontinued BOOLEAN,\n",
" registered TIMESTAMP,\n",
" description TEXT\n",
");\n",
"\n",
"CREATE TABLE orders (\n",
" orderid SERIAL PRIMARY KEY,\n",
" customerid INT,\n",
" employeeid INT,\n",
" orderdate TIMESTAMP,\n",
" requireddate TIMESTAMP,\n",
" shippeddate TIMESTAMP,\n",
" shipvia INT,\n",
" freight NUMERIC(10, 2),\n",
" shipname VARCHAR(100),\n",
" shipaddress VARCHAR(200),\n",
" shipcity VARCHAR(50),\n",
" shipregion VARCHAR(50),\n",
" shippostalcode VARCHAR(10),\n",
" shipcountry VARCHAR(50),\n",
" status VARCHAR(20)\n",
");\n",
"\n",
"CREATE TABLE customers (\n",
" customerid SERIAL PRIMARY KEY,\n",
" companyname VARCHAR(100),\n",
" contactname VARCHAR(100),\n",
" contacttitle VARCHAR(50),\n",
" address VARCHAR(200),\n",
" city VARCHAR(50),\n",
" region VARCHAR(50),\n",
" postalcode VARCHAR(10),\n",
" country VARCHAR(50),\n",
" phone VARCHAR(20),\n",
" fax VARCHAR(20),\n",
" email VARCHAR(100),\n",
" website VARCHAR(100),\n",
" preferredcontactmethod VARCHAR(50),\n",
" registrationdate DATE\n",
");\n",
"\n",
" ### SQL\n",
" Given the database schema, here is the SQL query that answers `{question}`:\n",
" ```sql\n",
" \"\"\".format(question=question)\n",
" eos_token_id = tokenizer.eos_token_id\n",
" \n",
" inputs = tokenizer(prompt, return_tensors=\"pt\").to(\"cuda\")\n",
" generated_ids = model.generate(\n",
" **inputs,\n",
" num_return_sequences=1,\n",
" eos_token_id=eos_token_id,\n",
" pad_token_id=eos_token_id,\n",
" max_new_tokens=400,\n",
" do_sample=False,\n",
" num_beams=1,\n",
" \n",
" )\n",
" \n",
" outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)\n",
" \n",
" import sqlparse\n",
" print (\"Question: \" + question)\n",
" print (\"SQL: \" + sqlparse.format(outputs[0].split(\"```sql\")[-1], reindent=True))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "00bd2603-1178-4dcf-bab3-67823fbec034",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "726f29e7e44a44dc803df70cc5261a6a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/3 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
"Setting `pad_token_id` to `eos_token_id`:32023 for open-end generation.\n",
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
"Setting `pad_token_id` to `eos_token_id`:32023 for open-end generation.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Question: Which employee had the highest salary in the 'Marketing' department, and what is the total budget of that department?\n",
"SQL: SELECT e.firstname,\n",
" e.lastname,\n",
" e.salary,\n",
" d.budget\n",
"FROM employees e\n",
"JOIN departments d ON e.departmentid = d.departmentid\n",
"WHERE d.departmentname ILIKE '%Marketing%'\n",
"ORDER BY e.salary DESC\n",
"LIMIT 1;\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
"Setting `pad_token_id` to `eos_token_id`:32023 for open-end generation.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Question: Identify the customer who placed the highest number of orders last year, and list the names and prices of the products they ordered most frequently.\n",
"SQL: WITH LastYearOrders AS\n",
" (SELECT customerid,\n",
" COUNT(*) AS order_count\n",
" FROM orders\n",
" WHERE orderdate >= (CURRENT_DATE - INTERVAL '1 year')\n",
" GROUP BY customerid\n",
" ORDER BY order_count DESC\n",
" LIMIT 1),\n",
" FrequentlyOrderedProducts AS\n",
" (SELECT op.orderid,\n",
" op.productid,\n",
" p.productname,\n",
" p.unitprice\n",
" FROM order_details op\n",
" JOIN products p ON op.productid = p.productid\n",
" JOIN LastYearOrders ly ON op.orderid = ly.orderid\n",
" ORDER BY op.orderid,\n",
" op.quantity DESC)\n",
"SELECT c.companyname,\n",
" fop.productname,\n",
" fop.unitprice\n",
"FROM FrequentlyOrderedProducts fop\n",
"JOIN customers c ON fop.customerid = c.customerid\n",
"JOIN orders o ON fop.orderid = o.orderid\n",
"WHERE o.orderdate >= (CURRENT_DATE - INTERVAL '1 year')\n",
"ORDER BY fop.orderid,\n",
" fop.quantity DESC;\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
"Setting `pad_token_id` to `eos_token_id`:32023 for open-end generation.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Question: Find the supplier that provided the most products currently out of stock, and show the names of these out-of-stock products.\n",
"SQL: SELECT s.supplierid,\n",
" s.companyname,\n",
" p.productname\n",
"FROM suppliers s\n",
"JOIN products p ON s.supplierid = p.supplierid\n",
"WHERE p.unitsinstock = 0\n",
"GROUP BY s.supplierid,\n",
" p.productname\n",
"ORDER BY COUNT(p.productid) DESC\n",
"LIMIT 1;\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
"Setting `pad_token_id` to `eos_token_id`:32023 for open-end generation.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Question: Determine the customer with the largest total order value last month, and list all the products and their quantities they ordered.\n",
"SQL: WITH LastMonthOrders AS\n",
" (SELECT o.customerid,\n",
" o.orderid,\n",
" op.productid,\n",
" op.quantity,\n",
" p.productname,\n",
" op.quantity * p.unitprice AS total_order_value\n",
" FROM orders o\n",
" JOIN order_details op ON o.orderid = op.orderid\n",
" JOIN products p ON op.productid = p.productid\n",
" WHERE o.orderdate >= date_trunc('month', CURRENT_DATE) - INTERVAL '1 month'\n",
" AND o.orderdate < date_trunc('month', CURRENT_DATE) ),\n",
" CustomerTotalOrderValue AS\n",
" (SELECT customerid,\n",
" SUM(total_order_value) AS total_value\n",
" FROM LastMonthOrders\n",
" GROUP BY customerid\n",
" ORDER BY total_value DESC\n",
" LIMIT 1)\n",
"SELECT c.customerid,\n",
" c.companyname,\n",
" lm.productid,\n",
" lm.productname,\n",
" lm.quantity\n",
"FROM CustomerTotalOrderValue ctov\n",
"JOIN LastMonthOrders lm ON ctov.customerid = lm.customerid\n",
"JOIN customers c ON c.customerid = lm.customerid;\n",
"Question: Who is the employee that handled the most orders last quarter, and what is the status of their most recent order?\n",
"SQL: WITH LastQuarterOrders AS\n",
" (SELECT employeeid,\n",
" COUNT(*) AS order_count,\n",
" MAX(orderdate) AS most_recent_order_date\n",
" FROM orders\n",
" WHERE orderdate >= date_trunc('quarter', CURRENT_DATE) - INTERVAL '3 months'\n",
" AND orderdate < date_trunc('quarter', CURRENT_DATE)\n",
" GROUP BY employeeid\n",
" ORDER BY order_count DESC\n",
" LIMIT 1),\n",
" MostRecentOrderStatus AS\n",
" (SELECT o.employeeid,\n",
" o.status\n",
" FROM orders o\n",
" JOIN LastQuarterOrders lqo ON o.employeeid = lqo.employeeid\n",
" WHERE o.orderdate = lqo.most_recent_order_date )\n",
"SELECT e.firstname,\n",
" e.lastname,\n",
" mros.status AS most_recent_order_status\n",
"FROM employees e\n",
"JOIN LastQuarterOrders lqo ON e.employeeid = lqo.employeeid\n",
"JOIN MostRecentOrderStatus mros ON lqo.employeeid = mros.employeeid;\n"
]
}
],
"source": [
"import torch\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"tokenizer = AutoTokenizer.from_pretrained(\"cfahlgren1/NaturalSQL-6.7B-v0\")\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
" \"cfahlgren1/NaturalSQL-6.7B-v0\",\n",
" device_map=\"auto\",\n",
" torch_dtype=torch.float16,\n",
")\n",
"\n",
"for question in questions:\n",
" \n",
" prompt = \"\"\"### Task\n",
" Generate a SQL query to answer the following question:\n",
" `{question}`\n",
" \n",
" ### PostgreSQL Database Schema\n",
" This query will run on a database whose schema is represented in this string:\n",
"\n",
" CREATE TABLE employees (\n",
" employeeid SERIAL PRIMARY KEY,\n",
" firstname VARCHAR(50),\n",
" lastname VARCHAR(50),\n",
" birthdate DATE,\n",
" email VARCHAR(100),\n",
" phone VARCHAR(20),\n",
" salary NUMERIC(10, 2),\n",
" hiredate TIMESTAMP,\n",
" departmentid INT,\n",
" address VARCHAR(200),\n",
" city VARCHAR(50),\n",
" state VARCHAR(50),\n",
" zipcode VARCHAR(10),\n",
" country VARCHAR(50),\n",
" isactive BOOLEAN\n",
");\n",
"\n",
"CREATE TABLE departments (\n",
" departmentid SERIAL PRIMARY KEY,\n",
" departmentname VARCHAR(100),\n",
" managerid INT,\n",
" location VARCHAR(100),\n",
" phonenumber VARCHAR(20),\n",
" budget NUMERIC(10, 2),\n",
" startdate DATE,\n",
" enddate DATE,\n",
" description TEXT,\n",
" email VARCHAR(100),\n",
" website VARCHAR(100)\n",
");\n",
"\n",
"CREATE TABLE products (\n",
" productid SERIAL PRIMARY KEY,\n",
" productname VARCHAR(100),\n",
" supplierid INT,\n",
" categoryid INT,\n",
" quantityperunit VARCHAR(50),\n",
" unitprice NUMERIC(10, 2),\n",
" unitsinstock SMALLINT,\n",
" unitsonorder SMALLINT,\n",
" reorderlevel SMALLINT,\n",
" discontinued BOOLEAN,\n",
" registered TIMESTAMP,\n",
" description TEXT\n",
");\n",
"\n",
"CREATE TABLE orders (\n",
" orderid SERIAL PRIMARY KEY,\n",
" customerid INT,\n",
" employeeid INT,\n",
" orderdate TIMESTAMP,\n",
" requireddate TIMESTAMP,\n",
" shippeddate TIMESTAMP,\n",
" shipvia INT,\n",
" freight NUMERIC(10, 2),\n",
" shipname VARCHAR(100),\n",
" shipaddress VARCHAR(200),\n",
" shipcity VARCHAR(50),\n",
" shipregion VARCHAR(50),\n",
" shippostalcode VARCHAR(10),\n",
" shipcountry VARCHAR(50),\n",
" status VARCHAR(20)\n",
");\n",
"\n",
"CREATE TABLE customers (\n",
" customerid SERIAL PRIMARY KEY,\n",
" companyname VARCHAR(100),\n",
" contactname VARCHAR(100),\n",
" contacttitle VARCHAR(50),\n",
" address VARCHAR(200),\n",
" city VARCHAR(50),\n",
" region VARCHAR(50),\n",
" postalcode VARCHAR(10),\n",
" country VARCHAR(50),\n",
" phone VARCHAR(20),\n",
" fax VARCHAR(20),\n",
" email VARCHAR(100),\n",
" website VARCHAR(100),\n",
" preferredcontactmethod VARCHAR(50),\n",
" registrationdate DATE\n",
");\n",
" \n",
" ### SQL\n",
" Given the database schema, here is the SQL query that answers `{question}`:\n",
" ```sql\n",
" \"\"\".format(question=question)\n",
" messages=[\n",
" { 'role': 'user', 'content': prompt}\n",
" ]\n",
" \n",
" inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors=\"pt\").to(model.device)\n",
" \n",
" # 32023 is the id of <|EOT|> token\n",
" outputs = model.generate(inputs, max_new_tokens=512, do_sample=False, top_k=50, top_p=0.95, num_return_sequences=1, eos_token_id=32023)\n",
"\n",
" print (\"Question: \" + question)\n",
" print (\"SQL: \" + sqlparse.format(tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True), reindent=True))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bba0e098-0e64-4979-a966-90ff337a3f7a",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "ComfyUI",
"language": "python",
"name": "comfyui"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment