Skip to content

Instantly share code, notes, and snippets.

@zilto
Last active July 16, 2023 14:51
Show Gist options
  • Save zilto/47ae718fce24aad7143d369b4f326010 to your computer and use it in GitHub Desktop.
Save zilto/47ae718fce24aad7143d369b4f326010 to your computer and use it in GitHub Desktop.
# functions from pinecone_module.py
from types import ModuleType
import numpy as np
import pinecone
def client_vector_db(vector_db_config: dict) -> ModuleType:
    """Initialize Pinecone and return the configured ``pinecone`` module.

    Every entry of ``vector_db_config`` is forwarded unchanged as a keyword
    argument to ``pinecone.init``. The module itself is returned and serves
    as the client handle.
    """
    init_kwargs = vector_db_config
    pinecone.init(**init_kwargs)
    return pinecone
def data_objects(
    ids: list[str], titles: list[str], embeddings: list[np.ndarray], metadata: dict
) -> list[tuple]:
    """Build ``(id, vector, metadata)`` tuples suitable for a Pinecone upsert.

    Args:
        ids: identifier of each record.
        titles: title of each record; stored under the ``"title"`` key of
            that record's metadata.
        embeddings: one vector per record. NumPy arrays (or any array-like
            sequence) are converted to plain Python lists.
        metadata: extra key/value pairs copied into every record's metadata.

    Returns:
        One ``(id, vector_as_list, metadata_dict)`` tuple per record, in
        input order.

    Raises:
        ValueError: if ``ids``, ``titles`` and ``embeddings`` do not all have
            the same length.
    """
    # Validate explicitly instead of with `assert`: assertions are removed
    # under `python -O`, and `zip` below would then silently drop records.
    if not len(ids) == len(titles) == len(embeddings):
        raise ValueError(
            "ids, titles and embeddings must have the same length; got "
            f"{len(ids)}, {len(titles)} and {len(embeddings)}"
        )
    properties = [dict(title=title, **metadata) for title in titles]
    # `np.asarray` is a no-op for ndarrays and lets plain lists through too.
    vectors = [np.asarray(vector).tolist() for vector in embeddings]
    return list(zip(ids, vectors, properties))
def push_to_vector_db(
    client_vector_db: ModuleType,
    index_name: str,
    data_objects: list[tuple],
    batch_size: int = 100,
) -> int:
    """Upsert objects into a Pinecone index in batches.

    Args:
        client_vector_db: the initialized Pinecone client (the ``pinecone``
            module as returned by ``client_vector_db``).
        index_name: name of the index to write to.
        data_objects: records to upsert, passed to ``upsert`` as ``vectors``.
        batch_size: maximum number of records sent per ``upsert`` call.

    Returns:
        The number of objects submitted, i.e. ``len(data_objects)``. This is
        not a count confirmed by the server.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    # A negative step would make `range` empty and silently skip every
    # record while still reporting them as pushed.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
    # Go through the injected client rather than the module global so the
    # function really depends on an initialized client.
    pinecone_index = client_vector_db.Index(index_name)
    for start in range(0, len(data_objects), batch_size):
        # Slicing clamps at the end of the list, so the last batch may be short.
        pinecone_index.upsert(vectors=data_objects[start : start + batch_size])
    return len(data_objects)
# functions from weaviate_module.py
import numpy as np
import weaviate
def client_vector_db(vector_db_config: dict) -> weaviate.Client:
    """Create a Weaviate client and check that the instance is usable.

    ``vector_db_config`` is unpacked as keyword arguments to
    ``weaviate.Client``.

    Raises:
        ConnectionError: if the client does not report itself as both live
            and ready.
    """
    weaviate_client = weaviate.Client(**vector_db_config)
    # Guard clause: fail fast when either health check is negative.
    if not (weaviate_client.is_live() and weaviate_client.is_ready()):
        raise ConnectionError("Error creating Weaviate client")
    return weaviate_client
def data_objects(
    ids: list[str], titles: list[str], text_contents: list[str], metadata: dict
) -> list[dict]:
    """Build one property dict per record for insertion into Weaviate.

    Each dict has the keys ``squad_id``, ``title`` and ``context`` plus every
    key of ``metadata`` (presumably matching the Weaviate class schema
    defined elsewhere; confirm against that schema).

    Args:
        ids: identifier of each record, stored as ``squad_id``.
        titles: title of each record.
        text_contents: text of each record, stored as ``context``.
        metadata: extra key/value pairs copied into every record.

    Returns:
        The list of property dicts, in input order.

    Raises:
        ValueError: if ``ids``, ``titles`` and ``text_contents`` do not all
            have the same length.
    """
    # Validate explicitly instead of with `assert`: assertions are removed
    # under `python -O`, and `zip` below would then silently drop records.
    if not len(ids) == len(titles) == len(text_contents):
        raise ValueError(
            "ids, titles and text_contents must have the same length; got "
            f"{len(ids)}, {len(titles)} and {len(text_contents)}"
        )
    return [
        dict(squad_id=id_, title=title, context=context, **metadata)
        for id_, title, context in zip(ids, titles, text_contents)
    ]
def push_to_vector_db(
    client_vector_db: weaviate.Client,
    class_name: str,
    data_objects: list[dict],
    embeddings: list[np.ndarray],
    batch_size: int = 100,
) -> int:
    """Push data objects and their embeddings to Weaviate through a batch.

    Args:
        client_vector_db: Weaviate client whose ``batch`` interface is used.
        class_name: name of the Weaviate class the objects are added to.
        data_objects: property dicts, one per object.
        embeddings: vector for each object, aligned by position with
            ``data_objects``.
        batch_size: batch size passed to the client's batch configuration
            (dynamic batching is enabled).

    Returns:
        The number of objects submitted, i.e. ``len(data_objects)``.

    Raises:
        ValueError: if ``data_objects`` and ``embeddings`` differ in length.
    """
    # Checked before the batch is opened so nothing is sent on bad input.
    # An explicit raise survives `python -O`, unlike `assert`.
    if len(data_objects) != len(embeddings):
        raise ValueError(
            "data_objects and embeddings must have the same length; got "
            f"{len(data_objects)} and {len(embeddings)}"
        )
    with client_vector_db.batch(batch_size=batch_size, dynamic=True) as batch:
        for data_object, embedding in zip(data_objects, embeddings):
            batch.add_data_object(
                data_object=data_object, class_name=class_name, vector=embedding
            )
    return len(data_objects)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment