Created
December 4, 2024 09:13
-
-
Save PrashantSaikia/368a32ecb9efbb65ec7c78dae6a41059 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
from chromadb import PersistentClient | |
from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction | |
from chromadb.utils.data_loaders import ImageLoader | |
from PIL import Image | |
import fitz # PyMuPDF | |
def save_images_from_pdfs(pdf_folder, output_folder):
    """Render every page of every PDF in a folder to a PNG file.

    Each page is written to ``output_folder`` as
    ``<pdf name without extension>_page_<1-based page number>.png``.
    Directory entries whose name does not end in ``.pdf`` are skipped, and
    sub-directories of ``pdf_folder`` are not descended into.

    Args:
        pdf_folder: Directory that is scanned for PDF files.
        output_folder: Directory that receives the PNG files; it is created
            (including missing parents) when it does not exist yet.
    """
    # exist_ok avoids the race between a separate exists() check and the
    # directory creation.
    os.makedirs(output_folder, exist_ok=True)

    for pdf_file in os.listdir(pdf_folder):
        if not pdf_file.endswith('.pdf'):
            continue

        stem = os.path.splitext(pdf_file)[0]
        # Open the document as a context manager so its handle is released
        # after the last page, even if rendering or saving a page fails.
        with fitz.open(os.path.join(pdf_folder, pdf_file)) as doc:
            for page_number, page in enumerate(doc, start=1):
                pix = page.get_pixmap()
                # Save image as PNG
                output_file = os.path.join(output_folder, f"{stem}_page_{page_number}.png")
                pix.save(output_file)

    print(f"PDFs have been converted to images in folder: {output_folder}")
def initialize_vector_db(db_path):
    """Open the persistent ChromaDB store and return its image collection.

    The collection is named ``multimodal_rag`` and is fetched if it already
    exists under ``db_path`` or created otherwise. It is configured with an
    ``ImageLoader`` data loader and an ``OpenCLIPEmbeddingFunction``.

    Args:
        db_path: Path handed to ``PersistentClient`` as its storage location.

    Returns:
        The collection object returned by ``get_or_create_collection``.
    """
    return PersistentClient(path=db_path).get_or_create_collection(
        name="multimodal_rag",
        data_loader=ImageLoader(),
        embedding_function=OpenCLIPEmbeddingFunction(),
    )
def add_images_to_db(image_folder, image_vdb):
    """Add every PNG file in a folder to the vector database.

    Files are visited in sorted name order. The id of an image is the string
    form of its position in the sorted listing of *all* directory entries, so
    ids can have gaps when the folder also holds non-PNG files.

    Args:
        image_folder: Directory that is scanned for files ending in ``.png``.
        image_vdb: Object exposing ``add(ids=..., uris=...)``; the script
            passes the collection from ``initialize_vector_db``.

    Returns:
        None. When the folder holds no PNG files nothing is added and a
        notice is printed instead of the success message.
    """
    ids = []
    uris = []
    for i, filename in enumerate(sorted(os.listdir(image_folder))):
        if filename.endswith('.png'):
            ids.append(str(i))
            uris.append(os.path.join(image_folder, filename))

    if not ids:
        # NOTE(review): Chroma's Collection.add is expected to reject empty
        # id lists; skip the call rather than fail or report a false success.
        print(f"No PNG images found in folder: {image_folder}")
        return

    image_vdb.add(ids=ids, uris=uris)
    print("Images added to the database.")
def validate_vector_db(image_vdb):
    """Report how many entries the vector database holds.

    Args:
        image_vdb: Object exposing a ``count()`` method.

    Returns:
        The value returned by ``image_vdb.count()``, which is also printed.
    """
    total_entries = image_vdb.count()
    message = f"Number of images in the database: {total_entries}"
    print(message)
    return total_entries
if __name__ == "__main__":
    # Pipeline locations: source PDFs, rendered page images, vector store.
    pdf_folder, output_folder, db_path = "./TariffDocs", "./images", "./image_vdb"

    # 1. Render each PDF page to a PNG in the image folder.
    save_images_from_pdfs(pdf_folder, output_folder)

    # 2. Open (or create) the image collection.
    image_vdb = initialize_vector_db(db_path)

    # 3. Register the rendered images with the collection.
    add_images_to_db(output_folder, image_vdb)

    # 4. Print the resulting entry count as a sanity check.
    validate_vector_db(image_vdb)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment