Skip to content

Instantly share code, notes, and snippets.

@kristopherjohnson
Last active February 1, 2023 17:00
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kristopherjohnson/fe1f6275c84cca71954419d26f0f9bdd to your computer and use it in GitHub Desktop.
Save kristopherjohnson/fe1f6275c84cca71954419d26f0f9bdd to your computer and use it in GitHub Desktop.
Search code using OpenAI API
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Code Search\n",
"\n",
"Search for functions in a Python codebase that match an English query, using the OpenAI API.\n",
"\n",
"Adapted from sample code at <https://github.com/openai/openai-cookbook/blob/main/examples/Code_search.ipynb>\n",
"\n",
"Dependencies: `pip3 install matplotlib openai pandas plotly python-dotenv scikit-learn`\n",
"\n",
"To use this notebook, set appropriate configuration values in the following cell, run all the cells, and then go down to the bottom and make calls to the `search_functions()` function.\n",
"\n",
"The first time the notebook is run, it will read files from the codebase, generate embeddings for all the functions, and then store that data in `code_search_embeddings.csv`. On subsequent runs, if that file exists, the notebook will read it instead of reading anything from the codebase. If the codebase changes significantly, you can delete that file to force the embeddings data to be recreated."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import ast\n",
"import os\n",
"from glob import glob\n",
"\n",
"import dotenv\n",
"import openai\n",
"import pandas as pd\n",
"\n",
"from openai.embeddings_utils import get_embedding, cosine_similarity\n",
"\n",
"\n",
"# Configuration Values\n",
"\n",
"# Use .env for overrides of these variables\n",
"dotenv.load_dotenv('.env')\n",
"\n",
"CODE_ROOT_DIR = os.getenv('CODE_ROOT_DIR', os.path.expanduser(\"~/kobiton/ita-ai-service/src/\"))\n",
"EMBEDDINGS_CSV = os.getenv('EMBEDDINGS_CSV', 'code_search_embeddings.csv')\n",
"\n",
"# OpenAI currently recommends text-embedding-ada-002 for all use cases.\n",
"# See <https://openai.com/blog/new-and-improved-embedding-model/>\n",
"EMBEDDING_ENGINE = os.getenv('EMBEDDING_ENGINE', 'text-embedding-ada-002')\n",
"\n",
"# Create an OpenAI API key at https://beta.openai.com/account/api-keys\n",
"# \n",
"# To avoid storing the API key in plaintext in this file, create a file .env in\n",
"# this directory and add a line like this:\n",
"#\n",
"# OPENAI_API_KEY=ab-cdefghijklmnopqrstuvwxyz\n",
"#\n",
"OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_function_name(code):\n",
" \"\"\"\n",
" Extract the function name from a line beginning with \"def \".\n",
" \"\"\"\n",
" assert code.startswith(\"def \")\n",
" return code[len(\"def \"): code.index(\"(\")]\n",
"\n",
"def get_until_no_space(all_lines, i) -> str:\n",
" \"\"\"\n",
" Get all lines until a line outside the function definition is found.\n",
" \"\"\"\n",
" ret = [all_lines[i]]\n",
" for j in range(i + 1, i + 10000):\n",
" if j < len(all_lines):\n",
" if len(all_lines[j]) == 0 or all_lines[j][0] in [\" \", \"\\t\", \")\"]:\n",
" ret.append(all_lines[j])\n",
" else:\n",
" break\n",
" return \"\\n\".join(ret)\n",
"\n",
"def get_functions(filepath):\n",
" \"\"\"\n",
" Get all functions in a Python file.\n",
" \"\"\"\n",
" try:\n",
" whole_code = open(filepath).read().replace(\"\\r\", \"\\n\")\n",
" except UnicodeDecodeError:\n",
" # This can happen if a .py file can't be decoded as text using the default encoding.\n",
" # Just skip this file.\n",
" return\n",
" all_lines = whole_code.split(\"\\n\")\n",
" for i, l in enumerate(all_lines):\n",
" if l.startswith(\"def \"):\n",
" try:\n",
" code = get_until_no_space(all_lines, i)\n",
" function_name = get_function_name(code)\n",
" yield {\"code\": code, \"function_name\": function_name, \"filepath\": filepath}\n",
" except ValueError:\n",
" # This can happen if a line starting with \"def \" does not also contain \"(\".\n",
" # Just skip this function.\n",
" continue\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"openai.api_key = OPENAI_API_KEY"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Read already-generated embeddings if available; otherwise create them.\n",
"\n",
"if os.path.exists(EMBEDDINGS_CSV):\n",
" df = pd.read_csv(EMBEDDINGS_CSV)\n",
" \n",
" # Convert string representations back to float vectors\n",
" df['code_embedding'] = df['code_embedding'].apply(ast.literal_eval)\n",
"else:\n",
" code_root = CODE_ROOT_DIR\n",
"\n",
" code_files=[]\n",
" for dirpath, dirnames, filenames in os.walk(code_root, topdown=True):\n",
" code_files.extend(glob(os.path.join(dirpath, '*.py')))\n",
" # Don't descend into any directories named 'venv', which contain Python virtual environments\n",
" if 'venv' in dirnames:\n",
" dirnames.remove('venv')\n",
" print(\"Total number of py files:\", len(code_files))\n",
"\n",
" if len(code_files) == 0:\n",
" print(f\"No py files were found in {code_root}.\")\n",
"\n",
" all_funcs = []\n",
" for code_file in code_files:\n",
" funcs = list(get_functions(code_file))\n",
" for func in funcs:\n",
" all_funcs.append(func)\n",
"\n",
" print(\"Total number of functions extracted:\", len(all_funcs))\n",
" df = pd.DataFrame(all_funcs)\n",
" \n",
" df['code_embedding'] = df['code'].apply(lambda x: get_embedding(x, engine=EMBEDDING_ENGINE))\n",
" df['filepath'] = df['filepath'].apply(lambda x: x.replace(code_root, \"\"))\n",
" df.to_csv(EMBEDDINGS_CSV, index=False)\n",
"\n",
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def search_functions(df, code_query, n=3, pprint=True, n_lines=7):\n",
" \"\"\"\n",
" Get the functions that most closely match the query.\n",
"\n",
" Returns a DataFrame with the n closest matches from the given DataFrame.\n",
" \"\"\"\n",
" embedding = get_embedding(code_query, engine=EMBEDDING_ENGINE)\n",
" df['similarities'] = df.code_embedding.apply(lambda x: cosine_similarity(x, embedding))\n",
"\n",
" res = df.sort_values('similarities', ascending=False).head(n)\n",
" if pprint:\n",
" for r in res.iterrows():\n",
" print(r[1].filepath+\":\"+r[1].function_name + \" score=\" + str(round(r[1].similarities, 3)))\n",
" print(\"\\n\".join(r[1].code.split(\"\\n\")[:n_lines]))\n",
" print('-'*70)\n",
" return res"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"res = search_functions(df, 'load weights and apply them to feature vector', n=5)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "9ae124e314517d9aba51bbd4e2667cd3cc71250be45fb6e7fcc97c3a54703858"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment