Skip to content

Instantly share code, notes, and snippets.

@rwalk
Last active November 24, 2023 13:44
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save rwalk/1350a7bc951ed1f716116de9b75e94b0 to your computer and use it in GitHub Desktop.
Save rwalk/1350a7bc951ed1f716116de9b75e94b0 to your computer and use it in GitHub Desktop.
Vector search example using star wars scripts (CPOSC 2023)
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "abfba17f-8f1e-4473-8515-3bbfec0031a0",
"metadata": {},
"source": [
"# Vector Search: Star Wars Scripts Example\n",
"\n",
"This is a demo I presented at the 2023 Central Pennsylvania Open Source Conference (CPOSC). Slides from this talk are available here:\n",
"\n",
"https://docs.google.com/presentation/d/1h-wodY9c5ljfKYy0cuMr85Yuc592j63fUCX0JyrxxKU\n",
"\n",
"In this notebook, we demonstrate the use of vector search against the scripts from the Star Wars episodes IV, V, VI. All data is from Kaggle:\n",
"\n",
"https://www.kaggle.com/datasets/xvivancos/star-wars-movie-scripts\n",
"\n",
"We use the open source Weaviate search engine for vector search. Weaviate can be run locally or with their paid cloud service. The demo below uses the cloud service. There are many options for vector search engines. These include offerings from Elastic, Pinecone, and milvus. \n",
"\n",
"https://weaviate.io/\n",
"\n",
"For vectorizing, we use Cohere embeddings (not open source). There are many other options for vectorizing, including BERT, OpenAI embeddings. Cohere embeddings seemlessly integrate with the Weaviate cloud service making it easy to try this out on a trial account.\n",
"\n",
"https://cohere.ai/"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "54690305-9567-432e-9c31-d5c1470f8333",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# install the weaviate-client\n",
"%pip install weaviate-client"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b29c0e5e-2e82-43e7-ac41-819cc745536d",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# connect to weaviate cluster\n",
"import weaviate\n",
"import json\n",
"import os\n",
"\n",
"# WEVIATE API KEY\n",
"weaviate_auth = weaviate.AuthApiKey(os.environ[\"WEAVIATE_API_KEY\"])\n",
"\n",
"# you'll need to set a cohere key in your environment or hardcode it here\n",
"cohere_api_key = os.environ[\"COHERE_API_KEY\"]\n",
"\n",
"# Set the URL for your weaviate cluster\n",
"WEAVIATE_URL = os.environ[\"WEAVIATE_URL\"]\n",
"\n",
"client = weaviate.Client(WEAVIATE_URL,\n",
" auth_client_secret=weaviate_auth,\n",
" additional_headers={\n",
" \"X-Cohere-Api-Key\": cohere_api_key\n",
" })\n",
"print(json.dumps(client.schema.get(), indent=2))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cb2e72ed-20ec-403a-b05a-bcc65686799d",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# create a schema to hold our lines\n",
"line_schema = {\n",
" \"class\": \"Line\",\n",
" \"description\": \"A line from a script\",\n",
" \"vectorizer\": \"text2vec-cohere\",\n",
" \"vectorIndexConfig\": {\n",
" \"distance\": \"dot\"\n",
" },\n",
" \"moduleConfig\": {\n",
" \"text2vec-cohere\": {\n",
" \"model\": \"multilingual-22-12\",\n",
" \"truncate\": \"RIGHT\"\n",
" }\n",
" },\n",
" \"properties\": [\n",
" {\n",
" \"dataType\": [\"string\"], \n",
" \"description\": \"The title of script from which this line came\",\n",
" \"name\": \"scriptTitle\"\n",
" },\n",
" {\n",
" \"dataType\": [\"int\"],\n",
" \"description\": \"The line number\",\n",
" \"name\": \"number\"\n",
" },\n",
" {\n",
" \"dataType\": [\"string\"],\n",
" \"description\": \"The speaker of the line\",\n",
" \"name\": \"speaker\"\n",
" },\n",
" {\n",
" \"dataType\": [\"text\"],\n",
" \"description\": \"The text of the line\",\n",
" \"moduleConfig\": {\n",
" \"text2vec-cohere\": {\n",
" \"skip\": False,\n",
" \"vectorizePropertyName\": False\n",
" }\n",
" },\n",
" \"name\": \"text\"\n",
" } \n",
" ]\n",
"}\n",
"\n",
"# add the schema to the index\n",
"client.schema.create_class(line_schema)\n",
"print(json.dumps(client.schema.get(), indent=2))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "82f5334e-3e82-4a12-9a0f-6604341ed564",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"# Configure a batch process\n",
"client.batch.configure(\n",
" batch_size=100, \n",
" dynamic=True,\n",
" timeout_retries=3,\n",
" callback=None,\n",
")\n",
"\n",
"DATAPATH = \"./starwars/\" # path to your starwars script files\n",
"\n",
"\n",
"def parse_script(filepath):\n",
" '''\n",
" this is a not fancy parser for the slightly odd file format of the star wars scripts\n",
" '''\n",
" script_title = filepath.name\n",
" lines = filepath.read_text().split(\"\\n\")[1:]\n",
" for line in lines:\n",
" if len(line.strip()) > 0:\n",
" number, speaker, text = line.split(\" \", maxsplit=2)\n",
" number = int(number.strip('\"'))\n",
" speaker = speaker.strip('\"').strip()\n",
" text = text.strip('\"').strip()\n",
"\n",
" # return a parsed Line object\n",
" yield {\n",
" \"scriptTitle\": script_title,\n",
" \"number\": number,\n",
" \"speaker\": speaker,\n",
" \"text\": text\n",
" }\n",
"\n",
"# Batch import all Line objects\n",
"with client.batch as batch:\n",
" \n",
" datapath = Path(DATAPATH)\n",
" for filepath in datapath.iterdir():\n",
" if filepath.match(\"*.txt\"):\n",
" print(f\"Working on {filepath}\")\n",
" for line_obj in parse_script(filepath):\n",
" client.batch.add_data_object(line_obj, \"Line\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ef7abf3d-f18a-437d-b1fb-2a045a52916c",
"metadata": {},
"outputs": [],
"source": [
"# Do some searches!\n",
"\n",
"# Hybrid fraction controls how keyword and vector searches are combined together\n",
"# Set to 0 for pure keyword\n",
"# Set to 1 for pure vector search\n",
"# Set to 0.5 for a 50/50 mix\n",
"hybrid_fraction = 1.0\n",
"\n",
"result = client.query.get(\"Line\", [\"speaker\", \"text\"]).with_hybrid(\"Kid, I'm your dad\", hybrid_fraction).do()\n",
"for hit in result[\"data\"][\"Get\"][\"Line\"][0:10]:\n",
" print(json.dumps(hit, indent=2))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3280f35f-c35d-4f8e-9f47-806d073aa4ce",
"metadata": {},
"outputs": [],
"source": [
"# If you need to delete your index and start over:\n",
"client.schema.delete_class(\"Line\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment