Skip to content

Instantly share code, notes, and snippets.

@rwalk
Created January 28, 2023 15:30
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rwalk/a0039464613fd1013ed44c49152173dc to your computer and use it in GitHub Desktop.
Save rwalk/a0039464613fd1013ed44c49152173dc to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "abfba17f-8f1e-4473-8515-3bbfec0031a0",
"metadata": {},
"source": [
"# Vector Search: Star Wars Scripts Example\n",
"\n",
"This is a demo I presented at the Tech Lancaster Meetup in January 2023. Slides from this talk are available here:\n",
"\n",
"https://docs.google.com/presentation/d/1oU0mMTKK56ML4aBtthCx7tbKmLr15q4pO6tNrdMRxts\n",
"\n",
"\n",
"In this notebook, we demonstrate the use of vector search against the scripts from Star Wars Episodes IV, V, and VI. All data is from Kaggle:\n",
"\n",
"https://www.kaggle.com/jsphyg/star-wars\n",
"\n",
"We use the open source Weaviate search engine for vector search. Weaviate can be run locally or with its paid cloud service. The demo below uses the cloud service. There are many other options for vector search engines, including offerings from Elastic, Pinecone, and Milvus.\n",
"\n",
"https://weaviate.io/\n",
"\n",
"For vectorizing, we use Cohere embeddings (not open source). There are many options for vectorizing, including BERT and OpenAI embeddings. Cohere embeddings integrate seamlessly with the Weaviate cloud service, making it easy to try this out on a trial account.\n",
"\n",
"https://cohere.ai/"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "54690305-9567-432e-9c31-d5c1470f8333",
"metadata": {},
"outputs": [],
"source": [
"# install the weaviate-client\n",
"%pip install weaviate-client"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b29c0e5e-2e82-43e7-ac41-819cc745536d",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# connect to weaviate cluster\n",
"import weaviate\n",
"import json\n",
"import os\n",
"\n",
"# you'll need to set a cohere key in your environment or hardcode it here\n",
"COHERE_KEY = os.environ[\"COHERE_KEY\"]\n",
"\n",
"# Set the URL for your weaviate cluster\n",
"WEAVIATE_URL = \"https://<your-url-subdomain>.weaviate.network\"\n",
"\n",
"client = weaviate.Client(WEAVIATE_URL,\n",
" additional_headers={\n",
" \"X-Cohere-Api-Key\": COHERE_KEY\n",
" } \n",
")\n",
"print(json.dumps(client.schema.get(), indent=2))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cb2e72ed-20ec-403a-b05a-bcc65686799d",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# create a schema to hold our lines\n",
"line_schema = {\n",
" \"class\": \"Line\",\n",
" \"description\": \"A line from a script\",\n",
" \"vectorizer\": \"text2vec-cohere\",\n",
" \"vectorIndexConfig\": {\n",
" \"distance\": \"dot\"\n",
" },\n",
" \"moduleConfig\": {\n",
" \"text2vec-cohere\": {\n",
" \"model\": \"multilingual-22-12\",\n",
" \"truncate\": \"RIGHT\"\n",
" }\n",
" },\n",
" \"properties\": [\n",
" {\n",
" \"dataType\": [\"string\"], \n",
" \"description\": \"The title of script from which this line came\",\n",
" \"name\": \"scriptTitle\"\n",
" },\n",
" {\n",
" \"dataType\": [\"int\"],\n",
" \"description\": \"The line number\",\n",
" \"name\": \"number\"\n",
" },\n",
" {\n",
" \"dataType\": [\"string\"],\n",
" \"description\": \"The speaker of the line\",\n",
" \"name\": \"speaker\"\n",
" },\n",
" {\n",
" \"dataType\": [\"text\"],\n",
" \"description\": \"The text of the line\",\n",
" \"name\": \"text\"\n",
" } \n",
" ]\n",
"}\n",
"\n",
"# add the schema to the index\n",
"client.schema.create_class(line_schema)\n",
"print(json.dumps(client.schema.get(), indent=2))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "82f5334e-3e82-4a12-9a0f-6604341ed564",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"# Configure a batch process\n",
"client.batch.configure(\n",
" batch_size=100, \n",
" dynamic=True,\n",
" timeout_retries=3,\n",
" callback=None,\n",
")\n",
"\n",
"DATAPATH = \"./Downloads/starwars/\" # path to your starwars script files\n",
"\n",
"\n",
"def parse_script(filepath):\n",
" '''\n",
" this is a not fancy parser for the slightly odd file format of the star wars scripts\n",
" '''\n",
" script_title = filepath.name\n",
" lines = filepath.read_text().split(\"\\n\")[1:]\n",
" for line in lines:\n",
" if len(line.strip()) > 0:\n",
" number, speaker, text = line.split(\" \", maxsplit=2)\n",
" number = int(number.strip('\"'))\n",
" speaker = speaker.strip('\"').strip()\n",
" text = text.strip('\"').strip()\n",
"\n",
" # return a parsed Line object\n",
" yield {\n",
" \"scriptTitle\": script_title,\n",
" \"number\": number,\n",
" \"speaker\": speaker,\n",
" \"text\": text\n",
" }\n",
"\n",
"# Batch import all Line objects\n",
"with client.batch as batch:\n",
" \n",
" datapath = Path(DATAPATH)\n",
" for filepath in datapath.iterdir():\n",
" if filepath.match(\"*.txt\"):\n",
" for line_obj in parse_script(filepath):\n",
" client.batch.add_data_object(line_obj, \"Line\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "ef7abf3d-f18a-437d-b1fb-2a045a52916c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{\n",
" \"speaker\": \"VADER\",\n",
" \"text\": \"No. I am your father.\"\n",
"}\n",
"{\n",
" \"speaker\": \"LUKE\",\n",
" \"text\": \"He's my father.\"\n",
"}\n",
"{\n",
" \"speaker\": \"YODA\",\n",
" \"text\": \"Your father he is.\"\n",
"}\n",
"{\n",
" \"speaker\": \"LUKE\",\n",
" \"text\": \"I've accepted the truth that you were once Anakin Skywalker, my father.\"\n",
"}\n",
"{\n",
" \"speaker\": \"LEIA\",\n",
" \"text\": \"Your father?\"\n",
"}\n",
"{\n",
" \"speaker\": \"LUKE\",\n",
" \"text\": \"You told me Vader betrayed and murdered my father.\"\n",
"}\n",
"{\n",
" \"speaker\": \"LUKE\",\n",
" \"text\": \"Master Yoda... is Darth Vader my father?\"\n",
"}\n",
"{\n",
" \"speaker\": \"LUKE\",\n",
" \"text\": \"I know, father.\"\n",
"}\n",
"{\n",
" \"speaker\": \"LUKE\",\n",
" \"text\": \"I found out Darth Vader was my father.\"\n",
"}\n",
"{\n",
" \"speaker\": \"VADER\",\n",
" \"text\": \"Son, come with me.\"\n",
"}\n"
]
}
],
"source": [
"# Do some searches!\n",
"\n",
"# Hybrid fraction controls how keyword and vector searches are combined together\n",
"# Set to 0 for pure keyword\n",
"# Set to 1 for pure vector search\n",
"# Set to 0.5 for a 50/50 mix\n",
"hybrid_fraction = 1\n",
"\n",
"result = client.query.get(\"Line\", [\"speaker\", \"text\"]).with_hybrid(\"Kid, I'm your dad\", hybrid_fraction).do()\n",
"for hit in result[\"data\"][\"Get\"][\"Line\"][0:10]:\n",
" print(json.dumps(hit, indent=2))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3280f35f-c35d-4f8e-9f47-806d073aa4ce",
"metadata": {},
"outputs": [],
"source": [
"# If you need to delete your index and start over:\n",
"#client.schema.delete_class(\"Line\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment