|
# written by claude 3.5 sonnet |
|
|
|
import os |
|
import json |
|
import hashlib |
|
from PIL import Image |
|
import torch |
|
from transformers import AutoProcessor, AutoModelForCausalLM |
|
import numpy as np |
|
from sentence_transformers import SentenceTransformer |
|
from sklearn.metrics.pairwise import cosine_similarity |
|
import shutil |
|
|
|
# Run the captioning model on the GPU when PyTorch can see one, otherwise
# fall back to the CPU. (A torch.device is not a str; format it with an
# f-string if it ever needs to be printed.)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
def get_embedding(text, model):
    """Embed one string with ``model`` and return its embedding.

    ``model.encode`` is called on a one-element batch, and the first row of
    the result is returned.
    """
    batch = model.encode([text])
    return batch[0]
|
|
|
def classify_text(input_text, categories, model):
    """Return the category whose embedding is most similar to ``input_text``.

    Args:
        input_text (str): Text to classify (in this script, an image caption).
        categories (Iterable[str]): Candidate category labels; must be non-empty.
        model: Embedding model exposing ``encode(list_of_str)``, such as the
            module-level ``SentenceTransformer``.

    Returns:
        tuple: ``(best_category, score)``, where ``score`` is the cosine
        similarity between the embeddings of ``input_text`` and ``best_category``.

    Raises:
        ValueError: If ``categories`` is empty.
    """
    categories = list(categories)
    if not categories:
        raise ValueError("classify_text() needs at least one category")

    input_embedding = get_embedding(input_text, model)
    # Encode every label in one batched call rather than one model call per
    # label; organize_images() invokes this once per image.
    category_embeddings = model.encode(categories)

    similarities = cosine_similarity([input_embedding], category_embeddings)[0]
    most_similar_index = int(np.argmax(similarities))

    return categories[most_similar_index], similarities[most_similar_index]
|
|
|
# Sentence-embedding model used to compare image captions with category
# names (shared by every classify_text() call in organize_images()).
# Note: constructing it at import time loads the model immediately.
_SENTENCE_MODEL_NAME = 'all-MiniLM-L6-v2'
sentence_model = SentenceTransformer(_SENTENCE_MODEL_NAME)
|
|
|
def compute_sha256(file_path):
    """Return the hexadecimal SHA-256 digest of the file at ``file_path``.

    The file is streamed in 4096-byte chunks, so the whole file is never held
    in memory at once.
    """
    digest = hashlib.sha256()
    with open(file_path, "rb") as stream:
        while True:
            chunk = stream.read(4096)
            if not chunk:
                break
            digest.update(chunk)
    return digest.hexdigest()
|
|
|
def load_cache(cache_file):
    """Load the JSON image-description cache from ``cache_file``.

    organize_images() stores captions here keyed by each image's SHA-256 hex
    digest.

    Args:
        cache_file (str): Path to the JSON cache file.

    Returns:
        dict: The cached mapping. An empty dict is returned when the file does
        not exist, cannot be decoded, or does not contain a JSON object.
    """
    try:
        with open(cache_file, 'r', encoding='utf-8') as f:
            cache = json.load(f)
    except FileNotFoundError:
        return {}
    except ValueError as err:
        # JSONDecodeError and UnicodeDecodeError are both ValueErrors. The cache
        # is rewritten after every new caption, so an interrupted run can leave a
        # truncated file; it only saves recomputation, so start fresh instead of
        # crashing every subsequent run.
        print(f"Warning: ignoring unreadable cache file '{cache_file}': {err}")
        return {}

    if not isinstance(cache, dict):
        # Callers index and assign into the cache by key, so anything other
        # than a JSON object is unusable.
        print(f"Warning: ignoring cache file '{cache_file}': expected a JSON object")
        return {}
    return cache
|
|
|
def save_cache(cache_file, cache):
    """Write ``cache`` to ``cache_file`` as indented JSON, replacing it atomically.

    The JSON is written to a sibling ``.tmp`` file first and then moved over
    ``cache_file`` with ``os.replace``. A crash, Ctrl+C or serialization error
    part-way through therefore leaves the previous cache intact instead of a
    truncated file.

    Args:
        cache_file (str): Destination path of the JSON cache.
        cache (dict): JSON-serializable mapping to persist.

    Raises:
        TypeError: If ``cache`` contains values ``json`` cannot serialize.
        OSError: If the file cannot be written or moved into place.
    """
    tmp_file = f"{cache_file}.tmp"
    try:
        with open(tmp_file, 'w', encoding='utf-8') as f:
            json.dump(cache, f, indent=2)
        os.replace(tmp_file, cache_file)
    finally:
        # After a successful replace the temp file is gone; otherwise remove
        # the partial write so it does not linger next to the cache.
        if os.path.exists(tmp_file):
            os.remove(tmp_file)
|
|
|
def list_folders(directory):
    """Return the names of the immediate subdirectories of ``directory``.

    Args:
        directory (str): Directory whose child folders should be listed.

    Returns:
        list: Names (not full paths) of the entries in ``directory`` that are
        directories. On any error, a message is printed and an empty list is
        returned instead of raising.
    """
    try:
        entries = os.listdir(directory)
        return [
            name for name in entries
            if os.path.isdir(os.path.join(directory, name))
        ]
    except FileNotFoundError:
        message = f"Error: The directory '{directory}' was not found."
    except PermissionError:
        message = f"Error: Permission denied to access the directory '{directory}'."
    except Exception as exc:
        message = f"An error occurred: {str(exc)}"

    print(message)
    return []
|
|
|
def _load_caption_model(model_name):
    """Load the captioning processor and model for ``model_name``, model moved to ``device``."""
    processor = AutoProcessor.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
    return processor, model


def _caption_image(image_path, processor, model):
    """Generate one caption for the image at ``image_path`` using 4-beam search."""
    # The context manager closes the file handle PIL keeps open. The processor
    # runs inside it because PIL decodes pixel data lazily.
    with Image.open(image_path) as image:
        inputs = processor(images=image, return_tensors="pt").to(device)

    with torch.no_grad():
        generated_ids = model.generate(
            pixel_values=inputs["pixel_values"],
            max_length=50,
            num_beams=4,
            num_return_sequences=1,
        )

    # Decoding is tokenizer work only, so it needs no no_grad context.
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]


def organize_images(categories, folder_path, cache_file='image_descriptions_cache.json',
                    caption_model="microsoft/git-large-coco"):
    """Caption each image in ``folder_path`` and copy it into its best-matching category folder.

    Only files directly inside ``folder_path`` with a .jpg/.jpeg/.png/.bmp/.gif
    extension (case-insensitive) are processed. Each image is captioned, the
    caption is matched against ``categories`` with classify_text(), and the
    image is copied (not moved) to ``folder_path/<category>/<file name>``.

    Args:
        categories (Iterable[str]): Category names; each names a subfolder of
            ``folder_path``, which is created if it does not exist.
        folder_path (str): Directory containing the images to organize.
        cache_file (str): JSON file caching captions keyed by the image's
            SHA-256 digest, so unchanged images are not re-captioned.
        caption_model (str): Hugging Face model id used for captioning. It is
            loaded only if at least one image is missing from the cache.

    Raises:
        ValueError: If there are images to organize but ``categories`` is empty.
    """
    categories = list(categories)
    cache = load_cache(cache_file)

    image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif')
    image_files = [
        name for name in os.listdir(folder_path)
        if name.lower().endswith(image_extensions)
        # Skip directories whose names merely look like image files.
        and os.path.isfile(os.path.join(folder_path, name))
    ]

    if image_files and not categories:
        # Fail before loading any model rather than deep inside classification.
        raise ValueError(f"No categories given for the images in '{folder_path}'")

    captioner = None  # (processor, model), loaded on the first cache miss

    for image_file in image_files:
        image_path = os.path.join(folder_path, image_file)
        file_hash = compute_sha256(image_path)

        if file_hash in cache:
            description = cache[file_hash]
            print(f"Image: {image_file} (cached)")
        else:
            if captioner is None:
                captioner = _load_caption_model(caption_model)
            description = _caption_image(image_path, *captioner)

            cache[file_hash] = description
            # Persist after every new caption so an interrupted run keeps its work.
            save_cache(cache_file, cache)
            print(f"Image: {image_file} (new)")

        print(f"Description: {description}")

        classified_category, similarity_score = classify_text(description, categories, sentence_model)

        print(f"Classified as: {classified_category}")
        print(f"Similarity score: {similarity_score:.4f}")
        print()

        destination_dir = os.path.join(folder_path, classified_category)
        os.makedirs(destination_dir, exist_ok=True)
        shutil.copyfile(image_path, os.path.join(destination_dir, image_file))

    # Save updated cache
    save_cache(cache_file, cache)
|
|
|
if __name__ == "__main__":
    # Each immediate subfolder of the pictures directory is a category; the
    # images directly inside the directory are copied into the closest match.
    pictures_dir = "/path/to/your/pictures"
    organize_images(list_folders(pictures_dir), pictures_dir)
Really useful stuff, thanks!