Skip to content

Instantly share code, notes, and snippets.

@iamironz
Created October 29, 2023 11:24
Show Gist options
  • Save iamironz/411893138a58d54d7a054d65f4eee00f to your computer and use it in GitHub Desktop.
Save iamironz/411893138a58d54d7a054d65f4eee00f to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
import sys
import os
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration
import torch
def log_message(message):
    """Print *message* to stdout and append it as one line to ``log.txt``.

    ``log.txt`` is a relative path, so it is created/appended in the current
    working directory. The file is reopened in append mode on every call,
    so output from earlier runs is kept.
    """
    print(message)
    # Write UTF-8 explicitly: captions and file paths may contain non-ASCII
    # characters that the platform default encoding (e.g. cp1252 on Windows)
    # cannot represent, which would raise UnicodeEncodeError mid-run.
    with open("log.txt", "a", encoding="utf-8") as log_file:
        log_file.write(f"{message}\n")
# Model setup runs once, at import time; process_images() reads the
# module-level ``device``, ``processor`` and ``model`` created here.

# Use the default CUDA device when torch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
log_message(f"Device: {device}")

# Processor and model come from the same pretrained checkpoint so the
# preprocessing/decoding matches the weights.
log_message("Loading processor...")
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")

log_message("Loading model...")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")

# nn.Module.to() moves the parameters in place, so its return value is unused.
log_message("Moving model to device...")
model.to(device)
log_message("Model loaded")
def process_images(image_paths, batch_size=None):
    """Caption each image file in *image_paths* with the BLIP-2 model.

    Relies on the module-level ``processor``, ``model`` and ``device``.

    Args:
        image_paths: Paths of image files that PIL can open.
        batch_size: Maximum number of images passed to ``model.generate`` in
            one call. ``None`` (the default) sends every image in a single
            batch, matching the previous behavior; a positive int bounds
            peak memory when captioning large directories.

    Returns:
        List of decoded captions, one per path and in the same order. An
        empty *image_paths* returns ``[]`` without calling the model.

    Raises:
        ValueError: If *batch_size* is given and is less than 1.
    """
    image_paths = list(image_paths)
    log_message(f"Processing {len(image_paths)} images")
    if not image_paths:
        # Nothing to caption; don't hand an empty batch to the model.
        return []
    if batch_size is None:
        batch_size = len(image_paths)
    elif batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")

    generated_texts = []
    for start in range(0, len(image_paths), batch_size):
        images = []
        for image_path in image_paths[start:start + batch_size]:
            # Image.open is lazy: it reads only the header and keeps the file
            # open until the pixels are loaded. Opening every file up front
            # therefore holds one descriptor per image at the same time and
            # can exceed the open-file limit on big directories. copy() loads
            # the pixels into an independent in-memory image (same mode), so
            # the file is closed as soon as the `with` block exits.
            with Image.open(image_path) as image:
                images.append(image.copy())
        inputs = processor(images=images, return_tensors="pt").to(device)
        generated_ids = model.generate(**inputs)
        generated_texts.extend(
            processor.batch_decode(generated_ids, skip_special_tokens=True)
        )
    return generated_texts
# File extensions treated as images when scanning a directory. Anything else
# (notes, sidecar JSON, archives, ...) is skipped rather than handed to PIL,
# where a single unreadable file would abort the whole batch. Extend this
# allow-list if other formats need captioning.
_IMAGE_EXTENSIONS = frozenset(
    {".bmp", ".gif", ".jpeg", ".jpg", ".png", ".tif", ".tiff", ".webp"}
)


def _collect_image_paths(directory):
    """Return sorted paths of visible image files anywhere under *directory*."""
    image_paths = []
    for root, _dirs, files in os.walk(directory):
        for name in files:
            # Skip hidden files; this also covers macOS ".DS_Store".
            if name.startswith("."):
                continue
            if os.path.splitext(name)[1].lower() in _IMAGE_EXTENSIONS:
                image_paths.append(os.path.join(root, name))
    # os.walk order depends on the filesystem; sort for reproducible output.
    image_paths.sort()
    return image_paths


def main(input_path):
    """Caption one image file, or every image file under a directory.

    Each caption is logged as ``"<path>: <caption>"`` via ``log_message``.
    A directory is searched recursively; hidden files and files without a
    recognised image extension are skipped. A single file path is captioned
    as given (except a ``.DS_Store`` file, which is rejected).

    Args:
        input_path: Path to an image file or to a directory.

    Returns:
        0 on success (including a directory with no images), 1 if
        *input_path* is neither a directory nor an acceptable file.
    """
    log_message(f"Input path: {input_path}")
    if os.path.isdir(input_path):
        log_message("Input path is a directory")
        image_paths = _collect_image_paths(input_path)
        if not image_paths:
            # Nothing to caption; don't hand an empty batch to the model.
            log_message("No image files found")
            return 0
        captions = process_images(image_paths)
        for path, caption in zip(image_paths, captions):
            log_message(f"{path}: {caption.strip()}")
    elif os.path.isfile(input_path) and not input_path.endswith('.DS_Store'):
        log_message("Input path is a file")
        caption = process_images([input_path])[0]
        log_message(f"{input_path}: {caption.strip()}")
    else:
        log_message("Invalid input path")
        return 1
    return 0
if __name__ == "__main__":
    if len(sys.argv) != 2:
        log_message("Usage: script.py <path_to_file_or_folder>")
        # Exit non-zero (2 = command-line usage error) so shells and CI
        # see the misuse as a failure instead of a successful run.
        sys.exit(2)
    # main()'s return value becomes the exit status; a None return exits 0.
    sys.exit(main(sys.argv[1]))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment