-
-
Save zilto/47ae718fce24aad7143d369b4f326010 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# functions from pinecone_module.py
from types import ModuleType | |
import numpy as np | |
import pinecone | |
def client_vector_db(vector_db_config: dict) -> ModuleType:
    """Initialise the Pinecone SDK and return the `pinecone` module itself.

    Every entry of `vector_db_config` is forwarded unchanged as a keyword
    argument to `pinecone.init` (presumably the environment and API key --
    confirm against the configuration the caller supplies).
    """
    pinecone.init(**vector_db_config)
    # The module object doubles as the "client" handed to downstream functions.
    client = pinecone
    return client
def data_objects(
    ids: list[str], titles: list[str], embeddings: list[np.ndarray], metadata: dict
) -> list[tuple]:
    """Create valid pinecone objects (id, vector, metadata) tuples for upsert.

    Args:
        ids: one identifier per object.
        titles: one title per object; stored under the ``"title"`` metadata key.
        embeddings: one vector per object; each is converted via ``.tolist()``.
        metadata: key/value pairs copied into the metadata of every object.

    Returns:
        A list of ``(id, vector_as_list, metadata_dict)`` tuples in input order.

    Raises:
        ValueError: if ``ids``, ``titles`` and ``embeddings`` differ in length.
    """
    # Explicit check instead of `assert`: assertions are stripped under
    # `python -O`, after which `zip` would silently truncate to the shortest
    # input and drop objects without any error.
    if not (len(ids) == len(titles) == len(embeddings)):
        raise ValueError(
            "ids, titles and embeddings must have the same length; got "
            f"{len(ids)}, {len(titles)} and {len(embeddings)}"
        )

    properties = [dict(title=title, **metadata) for title in titles]
    vectors = [embedding.tolist() for embedding in embeddings]
    return list(zip(ids, vectors, properties))
def push_to_vector_db(
    client_vector_db: ModuleType,
    index_name: str,
    data_objects: list[tuple],
    batch_size: int = 100,
) -> int:
    """Upsert objects to a Pinecone index in batches.

    Args:
        client_vector_db: the initialised Pinecone client (the `pinecone`
            module as returned by `client_vector_db`).
        index_name: name of the index to write to.
        data_objects: `(id, vector, metadata)` tuples to upsert.
        batch_size: maximum number of objects sent per `upsert` call.

    Returns:
        The number of objects submitted, i.e. ``len(data_objects)``.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    # A negative step would make the loop below a silent no-op while still
    # reporting every object as inserted, so reject it up front.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Use the injected client rather than the module global, so the function
    # actually depends on the initialised client it is given.
    pinecone_index = client_vector_db.Index(index_name)
    for start in range(0, len(data_objects), batch_size):
        # Slicing clamps at the end of the list; no explicit min() needed.
        pinecone_index.upsert(vectors=data_objects[start : start + batch_size])
    return len(data_objects)
# functions from weaviate_module.py
import numpy as np | |
import weaviate | |
def client_vector_db(vector_db_config: dict) -> weaviate.Client:
    """Build a Weaviate client and verify that the server is reachable.

    Every entry of `vector_db_config` is forwarded unchanged as a keyword
    argument to `weaviate.Client`.

    Raises:
        ConnectionError: if the client reports it is not live or not ready.
    """
    client = weaviate.Client(**vector_db_config)
    # Guard clause: liveness is checked first, readiness only if live.
    if not (client.is_live() and client.is_ready()):
        raise ConnectionError("Error creating Weaviate client")
    return client
def data_objects(
    ids: list[str], titles: list[str], text_contents: list[str], metadata: dict
) -> list[dict]:
    """Create valid weaviate objects that match the defined schema.

    Args:
        ids: one identifier per object; stored under ``"squad_id"``.
        titles: one title per object; stored under ``"title"``.
        text_contents: one text passage per object; stored under ``"context"``.
        metadata: key/value pairs copied into every object.

    Returns:
        One dict per object, in input order.

    Raises:
        ValueError: if ``ids``, ``titles`` and ``text_contents`` differ in
            length.
    """
    # Explicit check instead of `assert`: assertions are stripped under
    # `python -O`, after which `zip` would silently truncate to the shortest
    # input and drop objects without any error.
    if not (len(ids) == len(titles) == len(text_contents)):
        raise ValueError(
            "ids, titles and text_contents must have the same length; got "
            f"{len(ids)}, {len(titles)} and {len(text_contents)}"
        )

    return [
        dict(squad_id=id_, title=title, context=context, **metadata)
        for id_, title, context in zip(ids, titles, text_contents)
    ]
def push_to_vector_db(
    client_vector_db: weaviate.Client,
    class_name: str,
    data_objects: list[dict],
    embeddings: list[np.ndarray],
    batch_size: int = 100,
) -> int:
    """Push data objects with their respective embeddings to Weaviate.

    Args:
        client_vector_db: connected Weaviate client.
        class_name: Weaviate class the objects are added to.
        data_objects: property dicts, one per object.
        embeddings: one vector per object, aligned with `data_objects`.
        batch_size: forwarded to the client's batch configuration.

    Returns:
        The number of objects submitted, i.e. ``len(data_objects)``.

    Raises:
        ValueError: if ``data_objects`` and ``embeddings`` differ in length.
    """
    # Explicit check instead of `assert`: assertions are stripped under
    # `python -O`, which would let a mismatch pair objects with wrong vectors
    # or fail midway through the batch.
    if len(data_objects) != len(embeddings):
        raise ValueError(
            "data_objects and embeddings must have the same length; got "
            f"{len(data_objects)} and {len(embeddings)}"
        )

    # NOTE(review): presumably the batch context manager sends any remaining
    # queued objects on exit -- confirm against the Weaviate client docs.
    with client_vector_db.batch(batch_size=batch_size, dynamic=True) as batch:
        for data_object, vector in zip(data_objects, embeddings):
            batch.add_data_object(
                data_object=data_object, class_name=class_name, vector=vector
            )
    return len(data_objects)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment