Safely Query Enterprise Databases with Langchain, OpenAI, and Gretel.ai
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyPm8YlLwKOhEKBUGn3buEMl",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/zredlined/f84c50771245ec15993b44f846c9cd0e/safely-query-enterprise-databases-with-langchain-openai-and-gretel-ai.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"# Safely query enterprise data with Langchain, OpenAI, and Gretel.ai\n",
"\n",
"In this notebook, we demonstrate querying a database with natural-language questions, using an MRKL agent-based approach built on the LangChain library and OpenAI's 'gpt-3.5-turbo-0613' LLM. To demonstrate privacy-preserving queries, we use Gretel's Tabular-DP model to create a differentially private synthetic version of the real-world data."
],
"metadata": {
"id": "ORgDtozj3yqw"
}
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "VlWoY2Qx-vt0"
},
"outputs": [],
"source": [
"!pip install -Uqq langchain openai\n",
"!pip install -Uqq smart_open tabulate"
]
},
{
"cell_type": "code",
"source": [
"# @title Create databases from real and synthetic data\n",
"\n",
"import os\n",
"import sqlite3\n",
"\n",
"import pandas as pd\n",
"from smart_open import open\n",
"from tabulate import tabulate\n",
"\n",
"# Please paste your OpenAI API key here:\n",
"OPENAI_KEY = \"\" # @param {type:\"string\"}\n",
"SYNTHETIC_DATA = \"https://gretel-public-website.s3.us-west-2.amazonaws.com/datasets/ibm_hr_attrition/ibm-hr-synthetic-epsilon-5.csv\" # @param {type:\"string\"}\n",
"REAL_DATA = \"https://gretel-public-website.s3.us-west-2.amazonaws.com/datasets/ibm_hr_attrition/ibm-hr-employee-attrition.csv\" # @param {type:\"string\"}\n",
"VERBOSE_LANGCHAIN = False # @param {type:\"boolean\"}\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = OPENAI_KEY\n",
"\n",
"\n",
"def create_sqlite_db_from_csv(csv_url, db_name, table_name=\"employee_data\"):\n",
"    # Connect to the SQLite database (it will be created if it doesn't exist)\n",
"    conn = sqlite3.connect(db_name)\n",
"\n",
"    # Load CSV file into a Pandas dataframe directly from the URL\n",
"    with open(csv_url, \"r\") as f:\n",
"        data_df = pd.read_csv(f)\n",
"\n",
"    # Write the dataframe to the SQLite database\n",
"    data_df.to_sql(table_name, conn, if_exists=\"replace\", index=False)\n",
"\n",
"    # Commit and close connection\n",
"    conn.commit()\n",
"    conn.close()\n",
"    print(f\"Database {db_name} with table {table_name} created successfully!\")\n",
"\n",
"\n",
"def run_and_compare_queries(synthetic, real, query: str):\n",
"    \"\"\"Compare outputs of Langchain Agents running on real vs. synthetic data\"\"\"\n",
"    query_template = f\"{query} Execute all necessary queries, and always return results to the query, no explanations or apologies please. Word wrap output every 50 characters.\"\n",
"\n",
"    result1 = synthetic.run(query_template)\n",
"    result2 = real.run(query_template)\n",
"\n",
"    print(\"=== Comparing Results for Query ===\")\n",
"    print(f\"Query: {query}\")\n",
"\n",
"    table_data = [\n",
"        {\"From Agent on Synthetic DB\": result1, \"From Agent on Real DB\": result2}\n",
"    ]\n",
"\n",
"    print(tabulate(table_data, headers=\"keys\", tablefmt=\"pretty\"))\n"
],
"metadata": {
"id": "ySD3ANEsCLw-",
"cellView": "form"
},
"execution_count": 2,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Create LangChain SQL agents (OpenAI functions agent type by default)\n",
"\n",
"from langchain.agents import AgentExecutor, create_sql_agent\n",
"from langchain.agents.agent_toolkits import SQLDatabaseToolkit\n",
"from langchain.agents.agent_types import AgentType\n",
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.llms.openai import OpenAI\n",
"from langchain.sql_database import SQLDatabase\n",
"\n",
"\n",
"def create_agent(\n",
"    db_uri,\n",
"    agent_type=AgentType.OPENAI_FUNCTIONS,\n",
"    verbose=VERBOSE_LANGCHAIN,\n",
"    temperature=0,\n",
"    model=\"gpt-3.5-turbo-0613\",\n",
"):\n",
"    db = SQLDatabase.from_uri(db_uri)\n",
"    toolkit = SQLDatabaseToolkit(db=db, llm=OpenAI(temperature=temperature))\n",
"\n",
"    return create_sql_agent(\n",
"        llm=ChatOpenAI(temperature=temperature, model=model),\n",
"        toolkit=toolkit,\n",
"        verbose=verbose,\n",
"        agent_type=agent_type,\n",
"    )\n",
"\n",
"\n",
"# Create SQLite databases from CSV datasets\n",
"create_sqlite_db_from_csv(\n",
"    SYNTHETIC_DATA, db_name=\"synthetic-sqlite.db\", table_name=\"synthetic_ibm_attrition\"\n",
")\n",
"create_sqlite_db_from_csv(\n",
"    REAL_DATA, db_name=\"real-sqlite.db\", table_name=\"real_ibm_attrition\"\n",
")\n",
"\n",
"# Create SQL agent to interact with synthetic IBM attrition data\n",
"agent_synthetic_db = create_agent(\"sqlite:////content/synthetic-sqlite.db\")\n",
"\n",
"# Create SQL agent to interact with real-world IBM attrition data\n",
"agent_real_db = create_agent(\"sqlite:////content/real-sqlite.db\")"
],
"metadata": {
"id": "7TEJhwf3-0r7",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "431b34cc-f248-401d-e75f-6fb4e99b9c3c"
},
"execution_count": 9,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Database synthetic-sqlite.db with table synthetic_ibm_attrition created successfully!\n",
"Database real-sqlite.db with table real_ibm_attrition created successfully!\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"# Query Databases Using LangChain SQL Agents\n",
"Compare the agents' answers on the real-world and synthetic data across a variety of prompts."
],
"metadata": {
"id": "-K8ZhIq58xJp"
}
},
{
"cell_type": "code",
"source": [
"prompt = \"Which 3 departments have the highest attrition rates? Return a list please.\"\n",
"run_and_compare_queries(synthetic=agent_synthetic_db, real=agent_real_db, query=prompt)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "PsGtqH5ZBLEJ",
"outputId": "b5f80754-4439-42b9-9261-3f2266c09afa"
},
"execution_count": 4,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"=== Comparing Results for Query ===\n",
"Query: Which 3 departments have the highest attrition rates? Return a list please.\n",
"+---------------------------------------------------------+---------------------------------------------------------+\n",
"| From Agent on Synthetic DB | From Agent on Real DB |\n",
"+---------------------------------------------------------+---------------------------------------------------------+\n",
"| The 3 departments with the highest attrition rates are: | The 3 departments with the highest attrition rates are: |\n",
"| 1. Research & Development - 144 attritions | 1. Research & Development - 133 attritions |\n",
"| 2. Sales - 77 attritions | 2. Sales - 92 attritions |\n",
"| 3. Human Resources - 9 attritions | 3. Human Resources - 12 attritions |\n",
"+---------------------------------------------------------+---------------------------------------------------------+\n"
]
}
]
},
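{
"cell_type": "markdown",
"source": [
"The agents' answers above list attrition counts, although the prompt asked for rates. As a cross-check that bypasses the LLM, the following sketch computes the attrition rate per department directly in SQL. It is an illustration rather than part of the original workflow: it assumes the `Attrition` column holds `Yes`/`No` strings and that both database files sit in the current working directory. `attrition_rate_by_department` is a helper name introduced here."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import sqlite3\n",
"\n",
"import pandas as pd\n",
"\n",
"\n",
"def attrition_rate_by_department(db_name, table_name):\n",
"    \"\"\"Return the share of rows with Attrition = 'Yes' for each department.\"\"\"\n",
"    # Assumption: Attrition holds 'Yes'/'No' strings (IBM HR attrition schema).\n",
"    sql = (\n",
"        \"SELECT Department, AVG(Attrition = 'Yes') AS attrition_rate \"\n",
"        f\"FROM {table_name} GROUP BY Department ORDER BY attrition_rate DESC\"\n",
"    )\n",
"    conn = sqlite3.connect(db_name)\n",
"    try:\n",
"        return pd.read_sql_query(sql, conn)\n",
"    finally:\n",
"        conn.close()\n",
"\n",
"\n",
"# Assumption: both database files were created in the current working directory.\n",
"print(attrition_rate_by_department(\"synthetic-sqlite.db\", \"synthetic_ibm_attrition\"))\n",
"print(attrition_rate_by_department(\"real-sqlite.db\", \"real_ibm_attrition\"))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},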
{
"cell_type": "code",
"source": [
"prompt = \"Show me a distribution of ages by 10 year increments. Return in list format please.\"\n",
"run_and_compare_queries(synthetic=agent_synthetic_db, real=agent_real_db, query=prompt)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Xf75KtFhFBnG",
"outputId": "4c7faa84-99b2-454f-d010-6334eb109567"
},
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"=== Comparing Results for Query ===\n",
"Query: Show me a distribution of ages by 10 year increments. Return in list format please.\n",
"+---------------------------------------------------------+---------------------------------------------------------+\n",
"| From Agent on Synthetic DB | From Agent on Real DB |\n",
"+---------------------------------------------------------+---------------------------------------------------------+\n",
"| Here is the distribution of ages by 10 year increments: | Here is the distribution of ages by 10-year increments: |\n",
"| | |\n",
"| - 10-19: 34 employees | - Age Group: 10-19, Count: 17 |\n",
"| - 20-29: 298 employees | - Age Group: 20-29, Count: 309 |\n",
"| - 30-39: 654 employees | - Age Group: 30-39, Count: 622 |\n",
"| - 40-49: 307 employees | - Age Group: 40-49, Count: 349 |\n",
"| - 50-59: 172 employees | - Age Group: 50-59, Count: 168 |\n",
"| - 60-69: 6 employees | - Age Group: 60-69, Count: 5 |\n",
"+---------------------------------------------------------+---------------------------------------------------------+\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"prompt = \"Which department travels the farthest from home?\"\n",
"run_and_compare_queries(synthetic=agent_synthetic_db, real=agent_real_db, query=prompt)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "hxPe1lrjFd5Q",
"outputId": "179efa21-ba63-429c-cee1-8fd114249890"
},
"execution_count": 6,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"=== Comparing Results for Query ===\n",
"Query: Which department travels the farthest from home?\n",
"+---------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------+\n",
"| From Agent on Synthetic DB | From Agent on Real DB |\n",
"+---------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------+\n",
"| The department that travels the farthest from home is the Research & Development department, with a distance of 29. | The department that travels the farthest from home is the Research & Development department, with a distance of 29. |\n",
"+---------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------+\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"# Protecting Privacy\n",
"Here, we illustrate a \"re-identification attack\", in which an attacker re-identifies individuals in supposedly de-identified data by combining known attributes. Removing direct identifiers is not enough when the remaining attributes, taken together, single out one person. Because records in the differentially private synthetic data do not correspond one-to-one to real individuals, an attacker cannot link a match back to a specific person, which helps thwart re-identification attacks and protect privacy."
],
"metadata": {
"id": "fuote06n9MZj"
}
},
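{
"cell_type": "markdown",
"source": [
"One way to gauge this risk without an LLM is to count how many attribute combinations occur only once in each table. The sketch below is an illustration under stated assumptions: it assumes the tables have `Age`, `Gender`, and `Department` columns and that both database files sit in the current working directory. `QUASI_IDENTIFIERS` and `count_small_groups` are names introduced here. In the real table, each such combination isolates one actual employee; in the synthetic table, a unique combination is a generated record that should not map to a specific person."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import sqlite3\n",
"\n",
"import pandas as pd\n",
"\n",
"# Assumption: these columns exist in both tables (IBM HR attrition schema).\n",
"QUASI_IDENTIFIERS = [\"Age\", \"Gender\", \"Department\"]\n",
"\n",
"\n",
"def count_small_groups(db_name, table_name, max_size=1):\n",
"    \"\"\"Count quasi-identifier combinations shared by at most max_size rows.\"\"\"\n",
"    cols = \", \".join(QUASI_IDENTIFIERS)\n",
"    sql = (\n",
"        f\"SELECT {cols}, COUNT(*) AS n FROM {table_name} \"\n",
"        f\"GROUP BY {cols} HAVING COUNT(*) <= {max_size}\"\n",
"    )\n",
"    conn = sqlite3.connect(db_name)\n",
"    try:\n",
"        return len(pd.read_sql_query(sql, conn))\n",
"    finally:\n",
"        conn.close()\n",
"\n",
"\n",
"for db_name, table_name in [\n",
"    (\"real-sqlite.db\", \"real_ibm_attrition\"),\n",
"    (\"synthetic-sqlite.db\", \"synthetic_ibm_attrition\"),\n",
"]:\n",
"    print(table_name, count_small_groups(db_name, table_name))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},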
{
"cell_type": "code",
"source": [
"prompt = \"Is there an employee who is Age 46, Female, and who works in Human Resources. If so, what is their monthly income, performance rating, and years since last promotion?\"\n",
"run_and_compare_queries(synthetic=agent_synthetic_db, real=agent_real_db, query=prompt)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "tUxB8cILyO_Q",
"outputId": "1d4af2d6-b5e8-4488-a865-ad70b08cf875"
},
"execution_count": 8,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"=== Comparing Results for Query ===\n",
"Query: Is there an employee who is Age 46, Female, and who works in Human Resources. If so, what is their monthly income, performance rating, and years since last promotion?\n",
"+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+\n",
"| From Agent on Synthetic DB | From Agent on Real DB |\n",
"+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+\n",
"| I'm sorry, but I don't have access to a database to execute the query and provide the results. However, you can try running the following query on your own database: | The employee who is 46 years old, female, and works in the Human Resources department has a monthly income of $3,423, a performance rating of 3, and it has been 5 years since their last promotion. |\n",
"| | |\n",
"| ```sql | |\n",
"| SELECT MonthlyIncome, PerformanceRating, YearsSinceLastPromotion | |\n",
"| FROM synthetic_ibm_attrition | |\n",
"| WHERE Age = 46 AND Gender = 'Female' AND Department = 'Human Resources' | |\n",
"| LIMIT 10 | |\n",
"| ``` | |\n",
"| | |\n",
"| This query will retrieve the monthly income, performance rating, and years since last promotion for employees who are 46 years old, female, and work in the Human Resources department. | |\n",
"+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+\n"
]
}
]
}
]
}