Created
January 8, 2024 07:19
-
-
Save cfahlgren1/ba17f01cf688c4229686dc3dfb4d4549 to your computer and use it in GitHub Desktop.
Natural SQL 6.7B Inference Notebook
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "c2bc2fc8-6a02-4a6d-bedf-e0c6ce6b86fb", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Requirement already satisfied: transformers==4.35.2 in /usr/local/lib/python3.10/dist-packages (4.35.2)\n", | |
"Collecting accelerate\n", | |
" Downloading accelerate-0.25.0-py3-none-any.whl.metadata (18 kB)\n", | |
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers==4.35.2) (3.9.0)\n", | |
"Requirement already satisfied: huggingface-hub<1.0,>=0.16.4 in /usr/local/lib/python3.10/dist-packages (from transformers==4.35.2) (0.20.2)\n", | |
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers==4.35.2) (1.24.1)\n", | |
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers==4.35.2) (23.2)\n", | |
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers==4.35.2) (6.0.1)\n", | |
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers==4.35.2) (2023.12.25)\n", | |
"Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers==4.35.2) (2.31.0)\n", | |
"Requirement already satisfied: tokenizers<0.19,>=0.14 in /usr/local/lib/python3.10/dist-packages (from transformers==4.35.2) (0.15.0)\n", | |
"Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers==4.35.2) (0.4.1)\n", | |
"Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers==4.35.2) (4.66.1)\n", | |
"Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate) (5.9.6)\n", | |
"Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.10/dist-packages (from accelerate) (2.1.0+cu118)\n", | |
"Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.16.4->transformers==4.35.2) (2023.12.2)\n", | |
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.16.4->transformers==4.35.2) (4.4.0)\n", | |
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (1.12)\n", | |
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.0)\n", | |
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.1.2)\n", | |
"Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (2.1.0)\n", | |
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.35.2) (2.1.1)\n", | |
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.35.2) (3.4)\n", | |
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.35.2) (1.26.13)\n", | |
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.35.2) (2022.12.7)\n", | |
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.10.0->accelerate) (2.1.2)\n", | |
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.10.0->accelerate) (1.3.0)\n", | |
"Downloading accelerate-0.25.0-py3-none-any.whl (265 kB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m265.7/265.7 kB\u001b[0m \u001b[31m5.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n", | |
"\u001b[?25hInstalling collected packages: accelerate\n", | |
"Successfully installed accelerate-0.25.0\n", | |
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", | |
"\u001b[0m\n", | |
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.3.2\u001b[0m\n", | |
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython -m pip install --upgrade pip\u001b[0m\n", | |
"Note: you may need to restart the kernel to use updated packages.\n" | |
] | |
} | |
], | |
"source": [ | |
"# Pin both packages so a fresh kernel reproduces the environment this notebook was run with\n", | |
"# (accelerate 0.25.0 is the version the recorded install output resolved to).\n", | |
"%pip install transformers==4.35.2 accelerate==0.25.0" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "737ae79e-1a12-4be1-9d4b-0a2e44866ce8", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" | |
] | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "44bd39a1ba014323aaddff1e6c6b49c7", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"model.safetensors.index.json: 0%| | 0.00/23.9k [00:00<?, ?B/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "99783b2f89e942e19caf5e49185c9928", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Downloading shards: 0%| | 0/3 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "71e315fd30f34a30bb740db88a644434", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"model-00001-of-00003.safetensors: 0%| | 0.00/4.94G [00:00<?, ?B/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "2d9359d0f29e46149db8fdf0692d2066", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"model-00002-of-00003.safetensors: 0%| | 0.00/4.95G [00:00<?, ?B/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "4fbe31a181ee4f688150b31cf9249053", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"model-00003-of-00003.safetensors: 0%| | 0.00/3.59G [00:00<?, ?B/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "1f36be84bd1e4a67a104fb6f13cc9918", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Loading checkpoint shards: 0%| | 0/3 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "dc1dcd8be89a4fe28274d7f682035b4d", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"generation_config.json: 0%| | 0.00/124 [00:00<?, ?B/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"import torch\n", | |
"from transformers import AutoModelForCausalLM, AutoTokenizer\n", | |
"\n", | |
"# Hub repository that provides both the tokenizer and the model weights.\n", | |
"MODEL_ID = \"cfahlgren1/NaturalSQL-6.7B-v0\"\n", | |
"\n", | |
"tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)\n", | |
"\n", | |
"# Weights are loaded as float16; device placement is delegated through device_map=\"auto\".\n", | |
"model = AutoModelForCausalLM.from_pretrained(\n", | |
"    MODEL_ID,\n", | |
"    torch_dtype=torch.float16,\n", | |
"    device_map=\"auto\",\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"id": "3ddcc882-d89a-4bbf-a7cc-86d153ee9cc3", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n", | |
"Setting `pad_token_id` to `eos_token_id`:32023 for open-end generation.\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"#######\n", | |
"\n", | |
"\n", | |
"Question: Show me the day with the most users joining\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n", | |
"Setting `pad_token_id` to `eos_token_id`:32023 for open-end generation.\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"SELECT created_at::DATE AS day, COUNT(*) AS user_count\n", | |
"FROM users\n", | |
"GROUP BY day\n", | |
"ORDER BY user_count DESC\n", | |
"LIMIT 1;\n", | |
"#######\n", | |
"\n", | |
"\n", | |
"Question: Show me the project that has a task with the most comments\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n", | |
"Setting `pad_token_id` to `eos_token_id`:32023 for open-end generation.\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"SELECT p.project_name, t.task_name, COUNT(c.comment_id) AS comment_count\n", | |
"FROM projects p\n", | |
"JOIN tasks t ON p.project_id = t.project_id\n", | |
"JOIN comments c ON t.task_id = c.task_id\n", | |
"GROUP BY p.project_name, t.task_name\n", | |
"ORDER BY comment_count DESC\n", | |
"LIMIT 1;\n", | |
"#######\n", | |
"\n", | |
"\n", | |
"Question: What is the ratio of users with gmail addresses vs without?\n", | |
"SELECT \n", | |
" SUM(CASE WHEN email ILIKE '%@gmail.com%' THEN 1 ELSE 0 END)::FLOAT / NULLIF(SUM(CASE WHEN email NOT ILIKE '%@gmail.com%' THEN 1 ELSE 0 END), 0) AS gmail_ratio\n", | |
"FROM \n", | |
" users;\n" | |
] | |
} | |
], | |
"source": [ | |
"# Natural-language questions to translate into SQL against the schema embedded in the prompt.\n", | |
"questions = [\n", | |
"    'Show me the day with the most users joining',\n", | |
"    'Show me the project that has a task with the most comments',\n", | |
"    'What is the ratio of users with gmail addresses vs without?',\n", | |
"]\n", | |
"\n", | |
"# Token id on which generation stops; it is also passed as the padding id.\n", | |
"# NOTE(review): an earlier comment named 32021 (<|EOT|>) while the call passed 32023;\n", | |
"# 32023 is kept because it is the value the code actually used - confirm against the tokenizer.\n", | |
"EOS_TOKEN_ID = 32023\n", | |
"\n", | |
"# Visual divider printed around each question/answer pair.\n", | |
"SEPARATOR = \"#\" * 30\n", | |
"\n", | |
"for question in questions:\n", | |
"    prompt = f\"\"\"\n", | |
" ### Task \n", | |
"\n", | |
" Generate a SQL query to answer the following question: `{question}` \n", | |
" \n", | |
" ### Database Schema \n", | |
" The query will run on a database with the following schema: \n", | |
" ```\n", | |
" CREATE TABLE users (\n", | |
" user_id SERIAL PRIMARY KEY,\n", | |
" username VARCHAR(50) NOT NULL,\n", | |
" email VARCHAR(100) NOT NULL,\n", | |
" password_hash TEXT NOT NULL,\n", | |
" created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP\n", | |
" );\n", | |
" \n", | |
" CREATE TABLE projects (\n", | |
" project_id SERIAL PRIMARY KEY,\n", | |
" project_name VARCHAR(100) NOT NULL,\n", | |
" description TEXT,\n", | |
" start_date DATE,\n", | |
" end_date DATE,\n", | |
" owner_id INTEGER REFERENCES users(user_id)\n", | |
" );\n", | |
" \n", | |
" CREATE TABLE tasks (\n", | |
" task_id SERIAL PRIMARY KEY,\n", | |
" task_name VARCHAR(100) NOT NULL,\n", | |
" description TEXT,\n", | |
" due_date DATE,\n", | |
" status VARCHAR(50),\n", | |
" project_id INTEGER REFERENCES projects(project_id)\n", | |
" );\n", | |
" \n", | |
" CREATE TABLE taskassignments (\n", | |
" assignment_id SERIAL PRIMARY KEY,\n", | |
" task_id INTEGER REFERENCES tasks(task_id),\n", | |
" user_id INTEGER REFERENCES users(user_id),\n", | |
" assigned_date DATE NOT NULL DEFAULT CURRENT_TIMESTAMP\n", | |
" );\n", | |
" \n", | |
" CREATE TABLE comments (\n", | |
" comment_id SERIAL PRIMARY KEY,\n", | |
" content TEXT NOT NULL,\n", | |
" created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,\n", | |
" task_id INTEGER REFERENCES tasks(task_id),\n", | |
" user_id INTEGER REFERENCES users(user_id)\n", | |
" );\n", | |
" ```\n", | |
" \n", | |
" ### Answer \n", | |
" Here is the SQL query that answers the question: `{question}` \n", | |
" \"\"\"\n", | |
"\n", | |
"    # Single-turn chat: the whole prompt is sent as one user message.\n", | |
"    messages = [{'role': 'user', 'content': prompt}]\n", | |
"\n", | |
"    print(SEPARATOR)\n", | |
"    print(\"Question: \" + question)\n", | |
"    print(\"SQL: \")\n", | |
"\n", | |
"    inputs = tokenizer.apply_chat_template(\n", | |
"        messages, add_generation_prompt=True, return_tensors=\"pt\"\n", | |
"    ).to(model.device)\n", | |
"\n", | |
"    # One unpadded prompt per call, so every position is a real token; passing the mask\n", | |
"    # and pad id explicitly avoids the attention-mask / pad-token warnings from generate().\n", | |
"    attention_mask = torch.ones_like(inputs)\n", | |
"\n", | |
"    # Greedy decoding (do_sample=False); top_k / top_p were dropped because they only\n", | |
"    # apply when sampling is enabled.\n", | |
"    outputs = model.generate(\n", | |
"        inputs,\n", | |
"        attention_mask=attention_mask,\n", | |
"        max_new_tokens=512,\n", | |
"        do_sample=False,\n", | |
"        num_return_sequences=1,\n", | |
"        eos_token_id=EOS_TOKEN_ID,\n", | |
"        pad_token_id=EOS_TOKEN_ID,\n", | |
"    )\n", | |
"\n", | |
"    # Decode only the newly generated tokens, skipping the echoed prompt.\n", | |
"    prompt_length = inputs.shape[-1]\n", | |
"    print(tokenizer.decode(outputs[0][prompt_length:]))\n", | |
"    print(SEPARATOR + \"\\n\\n\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "fde60744-8ad1-47e0-ba74-29dfa7046eac", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.10.12" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment