Skip to content

Instantly share code, notes, and snippets.

@eebmagic
Last active November 9, 2023 22:41
Show Gist options
  • Save eebmagic/7f560dcc37984752a1e2815e486abd74 to your computer and use it in GitHub Desktop.
Save eebmagic/7f560dcc37984752a1e2815e486abd74 to your computer and use it in GitHub Desktop.
An interface for ChromaDB that adds documents only when their data (document or metadata) has actually changed. It should also minimize embedding-function calls, although the embedding function is still called for metadata-only changes that won't change the embedding.
import hashlib
import json
import threading
import math
import time
class SafeInterface():
    """Wrap a ChromaDB-style collection so that only new or changed records are written.

    ``add`` inserts ids that are not yet stored. It rewrites an existing id only
    when its document (or, if metadata was supplied, its metadata) differs from
    the stored value. Unchanged records are skipped, so they never reach the
    collection's embedding function.
    """

    def __init__(self, collection, batchSize=200, threaded=True):
        """
        Args:
            collection: object exposing ``get``, ``add`` and ``update``
                (e.g. a ``chromadb`` Collection).
            batchSize: maximum number of records sent per collection call.
            threaded: if True, batches are written concurrently from worker
                threads; if False, they are written sequentially in the caller.
        """
        self.col = collection
        self.batchSize = batchSize
        self.threaded = threaded

    def __str__(self):
        output = ""
        output += f"SafeInterface:\n"
        output += f"\tCollection: {self.col}\n"
        output += f"\tBatch size: {self.batchSize}\n"
        return output

    @staticmethod
    def _digest(value):
        """Return a stable SHA-256 hex digest of a JSON-serialisable value.

        ``sort_keys`` makes metadata dicts hash identically regardless of key
        order. ``None`` hashes consistently as JSON ``null``.
        """
        canonical = json.dumps(value, sort_keys=True)
        return hashlib.sha256(canonical.encode()).hexdigest()

    def _addBatch(self, i, threadCount, start, end, ids, docs, metas=None, op="add"):
        """Write the slice ``[start:end]`` with ``self.col.<op>`` (``add``/``update``/``upsert``).

        Metadata is passed only when ``metas`` is non-empty. ``i`` and
        ``threadCount`` are used only for the progress message.
        """
        kwargs = {"ids": ids[start:end], "documents": docs[start:end]}
        if metas:
            kwargs["metadatas"] = metas[start:end]
        getattr(self.col, op)(**kwargs)
        print(f"Finished thread {i} / {threadCount}")

    def addInBatches(self, ids, docs, metas=None, op="add"):
        """Write ``ids``/``docs``/``metas`` to the collection in chunks of ``batchSize``.

        Args:
            ids, docs: parallel sequences of ids and documents.
            metas: optional parallel sequence of metadata dicts; falsy means
                "send no metadata".
            op: collection method to call per batch ("add", "update" or "upsert").

        Returns:
            The number of batches written.

        Raises:
            ValueError: if ``op`` is not a supported write method.
            Exception: the first exception raised by any batch. When threaded,
                it is re-raised here after all workers have finished.
        """
        if op not in ("add", "update", "upsert"):
            raise ValueError(f"Unsupported write operation: {op!r}")
        numItems = len(ids)
        starts = range(0, numItems, self.batchSize)
        batchCount = len(starts)
        jobs = [
            (n + 1, batchCount, start, min(start + self.batchSize, numItems), ids, docs, metas, op)
            for n, start in enumerate(starts)
        ]

        if not self.threaded:
            for job in jobs:
                self._addBatch(*job)
            return batchCount

        errors = []

        def runJob(job):
            try:
                self._addBatch(*job)
            except Exception as exc:  # collected and re-raised in the calling thread below
                errors.append(exc)

        threads = [threading.Thread(target=runJob, args=(job,)) for job in jobs]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
        if errors:
            raise errors[0]
        return batchCount

    def add(self, ids, documents, metadatas=None):
        """Insert new records and rewrite existing ones whose content changed.

        Args:
            ids: a single id string or a sequence of ids.
            documents: a single document (when ``ids`` is a str) or a sequence
                parallel to ``ids``.
            metadatas: optional metadata dict (single form) or a sequence
                parallel to ``ids``. When omitted, only documents are compared,
                and stored metadata is left untouched.

        Raises:
            ValueError: if the sequences have different lengths.
        """
        addStart = time.time()
        if isinstance(ids, str):
            ids = [ids]
            documents = [documents]
            if metadatas is not None:
                metadatas = [metadatas]
        if not metadatas:
            metadatas = None
        withMeta = metadatas is not None

        if len(documents) != len(ids) or (withMeta and len(metadatas) != len(ids)):
            raise ValueError("ids, documents and metadatas must have the same length")
        if not ids:
            return

        # Preserve input order and collapse duplicate ids (the last value wins),
        # so each batch's ids stay aligned with its documents and metadata.
        incoming = {}
        for n, idx in enumerate(ids):
            incoming[idx] = (documents[n], metadatas[n] if withMeta else None)

        existing = self.col.get(ids=list(incoming), include=["metadatas", "documents"])
        existingIds = set(existing['ids'])

        # Add new docs
        newIds = [idx for idx in incoming if idx not in existingIds]
        if newIds:
            print(f"Inserting {len(newIds)} new documents")
            iters = self.addInBatches(
                newIds,
                [incoming[idx][0] for idx in newIds],
                [incoming[idx][1] for idx in newIds] if withMeta else None,
            )
            print(f"Added in {iters} batches of size {self.batchSize}")

        # Find stored records whose content differs from the incoming values
        storedDocs = existing.get('documents') or [None] * len(existing['ids'])
        storedMetas = existing.get('metadatas') or [None] * len(existing['ids'])
        updatedIds = []
        for idx, oldDoc, oldMeta in zip(existing['ids'], storedDocs, storedMetas):
            if idx not in incoming:
                continue
            newDoc, newMeta = incoming[idx]
            docChanged = self._digest(oldDoc) != self._digest(newDoc)
            metaChanged = withMeta and self._digest(oldMeta) != self._digest(newMeta)
            if docChanged or metaChanged:
                updatedIds.append(idx)

        # Rewrite changed entries; `update` is used because `add` does not
        # overwrite ids that already exist in a ChromaDB collection.
        if updatedIds:
            print(f"Updating {len(updatedIds)} documents")
            iters = self.addInBatches(
                updatedIds,
                [incoming[idx][0] for idx in updatedIds],
                [incoming[idx][1] for idx in updatedIds] if withMeta else None,
                op="update",
            )
            print(f"Updated in {iters} batches of size {self.batchSize}")

        ignored = len(ids) - len(newIds) - len(updatedIds)
        if ignored != 0:
            print(f"Ignored {ignored} documents because they already exist in the collection")
        print(f"Interface finished add process in {time.time() - addStart} seconds")
if __name__ == '__main__':
    # Small manual demo: the same record is submitted twice, and the
    # collection count is printed before and after each submission.
    import chromadb
    from CustomEmbedding import MyEmbeddingFunction

    def report(collection):
        """Print how many records the collection currently holds."""
        print(f"Collection has: {collection.count()} docs")

    client = chromadb.EphemeralClient()
    col = client.create_collection(name="test", embedding_function=MyEmbeddingFunction())
    inf = SafeInterface(col)

    report(col)

    # Single-record form: plain strings rather than lists.
    inf.add(ids='1', documents='first document')
    report(col)

    # Same record again, in list form; expected to be skipped as unchanged.
    inf.add(ids=['1'], documents=['first document'])
    report(col)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment