
@ArnoFrost
Created November 2, 2023 08:20
Trying out an open_ai_clip demo
import os
import pickle

import faiss
import numpy as np
import torch
from PIL import Image
from googletrans import Translator
from transformers import CLIPProcessor, CLIPModel

# Constants
# FIXME: adjust the image directory
DIRECTORY_PATH = "/xxxx/训练"
DB_PATH = os.path.join(DIRECTORY_PATH, "vector_db.pkl")
class VectorDatabase:
    """FAISS-backed store that maps image file paths to CLIP embedding vectors."""

    def __init__(self, db_file: str = "vector_database.pkl", dimension=768):
        self.dimension = dimension
        self.db_file = db_file
        if os.path.exists(self.db_file):
            self.load_from_file()
        else:
            self.index = faiss.IndexIDMap(faiss.IndexFlatL2(dimension))
            self.file_paths = []

    def add_vector(self, vector, file_path):
        """Add a vector and its associated file path to the database, or update an existing entry."""
        if self.file_path_exists(file_path):
            idx = self.file_paths.index(file_path)
            self.index.remove_ids(np.array([idx], dtype=np.int64))
            self.index.add_with_ids(np.array([vector], dtype=np.float32), np.array([idx], dtype=np.int64))
        else:
            idx = len(self.file_paths)
            self.index.add_with_ids(np.array([vector], dtype=np.float32), np.array([idx], dtype=np.int64))
            self.file_paths.append(file_path)

    def search_vector(self, vector, k=1):
        """Query the k nearest vectors; returns (file_path, squared L2 distance) pairs."""
        distances, indices = self.index.search(np.array([vector], dtype=np.float32), k)
        return [(self.file_paths[i], distances[0][j]) for j, i in enumerate(indices[0]) if i != -1]

    def save_to_file(self):
        # FAISS indices are SWIG objects and cannot be pickled directly,
        # so serialize the index into a byte array first.
        with open(self.db_file, 'wb') as f:
            pickle.dump({'index': faiss.serialize_index(self.index), 'file_paths': self.file_paths}, f)

    def load_from_file(self):
        with open(self.db_file, 'rb') as f:
            data = pickle.load(f)
        self.index = faiss.deserialize_index(data['index'])
        self.file_paths = data['file_paths']

    def file_path_exists(self, file_path):
        """Check whether the given file path is already present in the database."""
        return file_path in self.file_paths
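
# A minimal sketch (hypothetical, not part of the demo) of using VectorDatabase on its own:
# add a couple of random vectors, query the nearest ones, and persist the index.
# The file name "demo_db.pkl" and the 4-dimensional vectors are illustrative only.
#
#   demo_db = VectorDatabase(db_file="demo_db.pkl", dimension=4)
#   demo_db.add_vector(np.random.rand(4).astype(np.float32), "a.jpg")
#   demo_db.add_vector(np.random.rand(4).astype(np.float32), "b.jpg")
#   print(demo_db.search_vector(np.random.rand(4).astype(np.float32), k=2))
#   demo_db.save_to_file()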
class ClipSearchEngine:
    def __init__(self):
        self.model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
        self.translator = Translator()

    def translate_to_english(self, texts: list[str]) -> list[str]:
        """Translate Chinese descriptions into English so they match CLIP's training language."""
        translations = self.translator.translate(texts, src='zh-CN', dest='en')
        return [t.text for t in translations]

    def get_image_vector(self, image_path):
        # The CLIP forward pass expects both modalities, so a dummy text is passed
        # alongside the image; only the image embedding is used.
        image = Image.open(image_path)
        inputs = self.processor(text=["dummy"], images=image, return_tensors="pt", padding=True)
        outputs = self.model(**inputs)
        return outputs.image_embeds.squeeze().cpu().detach().numpy()

    def get_text_vector(self, text_description: list[str]):
        # Symmetrically, an all-zero dummy image accompanies the text input.
        dummy_image = torch.zeros(1, 3, 224, 224)
        inputs = self.processor(text=text_description, images=dummy_image, return_tensors="pt", padding=True)
        outputs = self.model(**inputs)
        return outputs.text_embeds.squeeze().cpu().detach().numpy()
def initialize_or_load_db(db_path):
    db = VectorDatabase(db_file=db_path, dimension=768)
    if not os.path.exists(db_path):
        db.save_to_file()
    return db


def save_image_vectors(engine: ClipSearchEngine, db: VectorDatabase, directory_path: str):
    """Embed every image in the directory and add or update it in the vector database."""
    for filename in os.listdir(directory_path):
        if filename.lower().endswith((".jpg", ".png", ".jpeg", ".webp")):
            file_path = os.path.join(directory_path, filename)
            vector = engine.get_image_vector(file_path)
            db.add_vector(vector, file_path)


def search_images_by_text(engine: ClipSearchEngine, text_description: list[str], db: VectorDatabase,
                          max_distance=2.0):
    """Return (file_path, distance) pairs whose squared L2 distance does not exceed max_distance."""
    text_vector = engine.get_text_vector(text_description)
    results = db.search_vector(text_vector, k=len(db.file_paths))
    # Keep only results within the distance threshold (smaller distance = more similar).
    filtered_results = [result for result in results if result[1] <= max_distance]
    return filtered_results
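
# Note on the distance threshold: the CLIP forward pass returns L2-normalized image and
# text embeddings (at least in recent transformers releases), so the squared L2 distance
# reported by IndexFlatL2 relates to cosine similarity as d = 2 - 2 * cos(theta).
# A threshold of 1.7 therefore keeps results with cosine similarity of at least 0.15:
#
#   cos_min = 1 - 1.7 / 2  # = 0.15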
if __name__ == '__main__':
    search_engine = ClipSearchEngine()
    db = initialize_or_load_db(DB_PATH)
    save_image_vectors(search_engine, db, DIRECTORY_PATH)
    db.save_to_file()  # persist the index so images are not re-embedded on every run

    topK = 10
    max_distance = 1.7
    text_description = ["车辆"]  # "vehicle"
    translated_descriptions = search_engine.translate_to_english(text_description)
    print(translated_descriptions)

    results = search_images_by_text(search_engine, translated_descriptions, db, max_distance)
    for path, score in results[:topK]:
        print(f"Image: {path}, Distance: {score:.4f}")
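
# A possible follow-up sketch (assuming vector_db.pkl already exists and DB_PATH is
# unchanged): re-open the persisted index and query it without re-embedding the folder.
# The query "海滩" ("beach") is illustrative only.
#
#   engine = ClipSearchEngine()
#   db = VectorDatabase(db_file=DB_PATH, dimension=768)  # loads the pickled index
#   queries = engine.translate_to_english(["海滩"])
#   for path, dist in search_images_by_text(engine, queries, db, max_distance=1.7)[:5]:
#       print(path, dist)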