-
-
Save ArnoFrost/21c836cf23a99c7872a30a619b97b3a8 to your computer and use it in GitHub Desktop.
尝试open_ai_clip demo
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import pickle | |
import faiss | |
import numpy as np | |
import torch | |
from PIL import Image | |
from googletrans import Translator | |
from transformers import CLIPProcessor, CLIPModel | |
from scripts.ai.VectorDatabase import VectorDatabase | |
# Constant definitions
# FIXME: point this at the real image directory before running
DIRECTORY_PATH = "/xxxx/训练"
# Pickled vector database lives alongside the images.
DB_PATH = os.path.join(DIRECTORY_PATH, "vector_db.pkl")
class VectorDatabase:
    """Persistent vector store: a FAISS L2 index plus a parallel list of
    file paths, pickled to disk as a single file.

    Integer FAISS ids double as positions in ``self.file_paths``.
    """

    db_file = None  # path of the pickle file backing this database

    def __init__(self, db_file: str = "vector_database.pkl", dimension=768):
        self.dimension = dimension
        self.db_file = db_file
        if os.path.exists(self.db_file):
            self.load_from_file()
        else:
            # IndexIDMap lets us address vectors by our own integer ids.
            self.index = faiss.IndexIDMap(faiss.IndexFlatL2(dimension))
            self.file_paths = []

    def add_vector(self, vector, file_path):
        """Add a vector and its file path, or replace the stored vector
        if the path is already indexed."""
        vec = np.asarray([vector], dtype=np.float32)
        if self.file_path_exists(file_path):
            idx = self.file_paths.index(file_path)
            # Drop the stale embedding before re-adding under the same id.
            self.index.remove_ids(np.array([idx], dtype=np.int64))
            self.index.add_with_ids(vec, np.array([idx], dtype=np.int64))
        else:
            idx = len(self.file_paths)
            self.index.add_with_ids(vec, np.array([idx], dtype=np.int64))
            self.file_paths.append(file_path)

    def search_vector(self, vector, k=1):
        """Return up to k (file_path, L2 distance) pairs, nearest first.

        Returns [] for an empty database or non-positive k. FAISS pads
        the result with id -1 when fewer than k vectors exist; those
        entries are skipped instead of letting -1 silently index the
        last element of ``file_paths`` (the original behavior).
        """
        if k <= 0 or not self.file_paths:
            return []
        query = np.asarray([vector], dtype=np.float32)
        distances, indices = self.index.search(query, k)
        return [
            (self.file_paths[i], distances[0][j])
            for j, i in enumerate(indices[0])
            if i != -1
        ]

    def save_to_file(self):
        """Pickle the index and the path list to ``self.db_file``."""
        with open(self.db_file, 'wb') as f:
            pickle.dump({'index': self.index, 'file_paths': self.file_paths}, f)

    def load_from_file(self):
        """Restore index and path list from ``self.db_file``.

        NOTE(review): ``pickle.load`` on an untrusted file can execute
        arbitrary code — only load databases you created yourself.
        """
        with open(self.db_file, 'rb') as f:
            data = pickle.load(f)
        self.index = data['index']
        self.file_paths = data['file_paths']

    def file_path_exists(self, file_path):
        """True if ``file_path`` is already indexed (O(n) list scan)."""
        return file_path in self.file_paths
class ClipSearchEngine:
    """Thin wrapper around OpenAI CLIP (ViT-L/14) producing image and
    text embeddings, plus a zh-CN -> en translator for query text."""

    def __init__(self):
        model_name = "openai/clip-vit-large-patch14"
        self.model = CLIPModel.from_pretrained(model_name)
        self.processor = CLIPProcessor.from_pretrained(model_name)
        self.translator = Translator()

    def translate_to_english(self, text):
        """Translate Chinese text to English and return the first
        result's text (assumes *text* is a list, so googletrans returns
        a list of results — confirm against the caller)."""
        translated = self.translator.translate(text, src='zh-CN', dest='en')
        return translated[0].text

    def get_image_vector(self, image_path):
        """Embed a single image file and return it as a numpy array."""
        picture = Image.open(image_path)
        # CLIPProcessor expects both modalities; the text is a throwaway.
        batch = self.processor(
            text=["dummy"], images=picture, return_tensors="pt", padding=True
        )
        result = self.model(**batch)
        return result.image_embeds.squeeze().cpu().detach().numpy()

    def get_text_vector(self, text_description: list[str]):
        """Embed text description(s) and return them as a numpy array."""
        # Blank image placeholder — only the text embeddings are used.
        placeholder = torch.zeros(1, 3, 224, 224)
        batch = self.processor(
            text=text_description, images=placeholder,
            return_tensors="pt", padding=True,
        )
        result = self.model(**batch)
        return result.text_embeds.squeeze().cpu().detach().numpy()
def initialize_or_load_db(db_path):
    """Open the vector database stored at *db_path*.

    On first use (no file yet) an empty database is created and
    immediately persisted; otherwise the existing one is loaded by
    the VectorDatabase constructor.
    """
    is_new = not os.path.exists(db_path)
    database = VectorDatabase(db_file=db_path, dimension=768)
    if is_new:
        database.save_to_file()
    return database
def save_image_vectors(engine: ClipSearchEngine, db: VectorDatabase, directory_path: str):
    """Embed every image directly inside *directory_path* and upsert it
    into *db* keyed by its full file path.

    The extension check is case-insensitive, so files such as
    "photo.JPG" are no longer skipped (the original compared raw
    names, missing upper-case extensions). Persistence is the
    caller's responsibility — this function does not save the db.
    """
    image_extensions = (".jpg", ".png", ".jpeg", ".webp")
    for filename in os.listdir(directory_path):
        if filename.lower().endswith(image_extensions):
            file_path = os.path.join(directory_path, filename)
            db.add_vector(engine.get_image_vector(file_path), file_path)
def search_images_by_text(engine: ClipSearchEngine, text_description: list[str], db: VectorDatabase,
                          min_similarity_score=0):
    """Return (file_path, L2 distance) pairs within a distance threshold,
    nearest first.

    NOTE(review): despite its name, ``min_similarity_score`` acts as a
    MAXIMUM distance cutoff (smaller L2 distance = more similar); the
    default of 0 keeps only exact matches. The parameter name is kept
    unchanged for caller compatibility.
    """
    # Guard: searching an empty database would request k=0 from FAISS.
    if not db.file_paths:
        return []
    text_vector = engine.get_text_vector(text_description)
    results = db.search_vector(text_vector, k=len(db.file_paths))
    # Keep only hits at or under the distance threshold.
    return [result for result in results if result[1] <= min_similarity_score]
if __name__ == '__main__':
    search_engine = ClipSearchEngine()
    db = initialize_or_load_db(DB_PATH)
    save_image_vectors(search_engine, db, DIRECTORY_PATH)
    # Persist the newly indexed vectors. Previously nothing saved the
    # database after indexing (initialize_or_load_db only saves a brand
    # new empty db), so every run recomputed all embeddings.
    db.save_to_file()

    topK = 10
    max_similarity_score = 1.7  # maximum allowed L2 distance for a hit
    text_description = ["车辆"]
    translated_descriptions = search_engine.translate_to_english(text_description)
    print(translated_descriptions)
    results = search_images_by_text(search_engine, translated_descriptions, db, max_similarity_score)
    # db.search_vector returns results nearest-first, so the slice keeps
    # the topK best matches.
    for path, score in results[:topK]:
        print(f"Image: {path}, Similarity: {score:.4f}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment