Skip to content

Instantly share code, notes, and snippets.

@Glavin001
Forked from j0yk1ll/embed_score.py
Created May 21, 2024 01:28
Show Gist options
  • Save Glavin001/275bc82536d43cb3821b4ff90e487135 to your computer and use it in GitHub Desktop.
Save Glavin001/275bc82536d43cb3821b4ff90e487135 to your computer and use it in GitHub Desktop.
import base64
import functools
from io import BytesIO

import numpy as np
import ollama
import torch
import torch.nn as nn
from diffusers import DiffusionPipeline, StableDiffusionPipeline
from PIL import Image
from safetensors.torch import load_file
from transformers import CLIPProcessor, CLIPModel, AutoImageProcessor, AutoModel
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# float16 halves memory and is the usual choice for GPU inference, but on CPU
# many ops are unimplemented or very slow in half precision, so only use it
# when the pipeline actually runs on CUDA.
_diffusion_dtype = torch.float16 if device.type == "cuda" else torch.float32

# Stable Diffusion v1.5 loaded through the community "lpw_stable_diffusion"
# (long prompt weighting) pipeline so that long captions are not truncated.
diffusion_model = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    custom_pipeline="lpw_stable_diffusion",
    torch_dtype=_diffusion_dtype,
).to(device)
def generate_image(prompt, image_filename="output.png", *, width=512, height=512,
                   max_embeddings_multiples=3):
    """Render *prompt* with the module-level diffusion pipeline and save it.

    Args:
        prompt: Text prompt to render.
        image_filename: Path the generated image is saved to.
        width: Output width in pixels (default 512, the previous fixed value).
        height: Output height in pixels (default 512, the previous fixed value).
        max_embeddings_multiples: Passed to the long-prompt-weighting pipeline;
            per its documentation this controls how many multiples of the
            text encoder's normal prompt length are accepted (default 3, the
            previous fixed value).

    Returns:
        The first generated image from the pipeline output, after it has been
        written to ``image_filename``.
    """
    # https://github.com/huggingface/diffusers/tree/main/examples/community#long-prompt-weighting-stable-diffusion
    result = diffusion_model(
        prompt=prompt,
        width=width,
        height=height,
        max_embeddings_multiples=max_embeddings_multiples,
    )
    image = result.images[0]
    image.save(image_filename)
    return image
@functools.lru_cache(maxsize=None)
def _load_clip():
    """Load the CLIP ViT-B/32 processor and model once and reuse them."""
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
    return processor, model


def generate_image_embedding_with_clip(image):
    """Return the CLIP image embedding of *image*.

    The processor and model are loaded on first use and cached by
    ``_load_clip``. Previously they were reloaded and moved to ``device`` on
    every call.

    Args:
        image: Image accepted by ``CLIPProcessor`` (e.g. a PIL image).

    Returns:
        Row 0 of ``get_image_features`` output, i.e. the embedding of the first
        (for a single image, the only) item in the batch, on ``device``.
    """
    clip_processor, clip_model = _load_clip()
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        inputs = clip_processor(images=image, return_tensors="pt").to(device)
        image_features = clip_model.get_image_features(**inputs)
    return image_features[0]
@functools.lru_cache(maxsize=None)
def _load_dino():
    """Load the DINOv2-base image processor and model once and reuse them."""
    processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
    model = AutoModel.from_pretrained('facebook/dinov2-base').to(device)
    return processor, model


def generate_image_embedding_with_dino(image):
    """Return a DINOv2 embedding of *image* by mean-pooling its hidden states.

    The processor and model are loaded on first use and cached by
    ``_load_dino``. Previously they were reloaded and moved to ``device`` on
    every call.

    Args:
        image: Image accepted by the DINOv2 ``AutoImageProcessor``.

    Returns:
        The pooled embedding of the first (for a single image, the only) item
        in the batch, on ``device``.
    """
    dino_processor, dino_model = _load_dino()
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        inputs = dino_processor(images=image, return_tensors="pt").to(device)
        outputs = dino_model(**inputs)
        # Average over dim 1 of last_hidden_state; presumably the token axis
        # of a (batch, tokens, hidden) tensor — TODO confirm.
        image_features = outputs.last_hidden_state.mean(dim=1)
    return image_features[0]
def compare_captions(image, short_caption, long_caption):
    """Score two captions by how closely images regenerated from them match *image*.

    Each caption is rendered with ``generate_image`` and saved to
    ``short_caption_image.jpg`` / ``long_caption_image.jpg``. The original
    image and both regenerations are embedded with DINOv2 and then CLIP. For
    each embedder, the similarity between the original and each regeneration
    is computed with ``calc_cosine_similarity`` and printed as
    ``short_score long_score`` (DINO line first, then CLIP).

    Args:
        image: The reference image.
        short_caption: First caption to evaluate.
        long_caption: Second caption to evaluate.

    Returns:
        dict mapping ``"dino"`` and ``"clip"`` to ``(short_score, long_score)``.
    """
    short_caption_image = generate_image(short_caption, "short_caption_image.jpg")
    long_caption_image = generate_image(long_caption, "long_caption_image.jpg")

    # The same comparison is run with each embedding model, in this order.
    embedders = (
        ("dino", generate_image_embedding_with_dino),
        ("clip", generate_image_embedding_with_clip),
    )
    scores = {}
    for name, embed in embedders:
        reference_embedding = embed(image)
        short_score = calc_cosine_similarity(reference_embedding, embed(short_caption_image))
        long_score = calc_cosine_similarity(reference_embedding, embed(long_caption_image))
        print(short_score, long_score)
        scores[name] = (short_score, long_score)
    return scores
def calc_cosine_similarity(embedding_1, embedding_2):
    """Cosine similarity of two vectors, rescaled from [-1, 1] to [0, 1].

    Similarity is taken along dimension 0. Identical directions give 1.0,
    orthogonal vectors give 0.5, and opposite directions give 0.0.
    """
    cosine = nn.functional.cosine_similarity(embedding_1, embedding_2, dim=0).item()
    return (cosine + 1) / 2
if __name__ == "__main__":
    # Open the image as a context manager so its file handle is released once
    # the comparison is done; Image.open reads lazily and otherwise keeps the
    # file open.
    with Image.open("image.jpg") as image:
        short_caption = "how to build an industry for dollars"
        long_caption = "In the image, there is a small black house with a green roof situated in a grassy area surrounded by trees. The house appears to be under construction or renovation, as there are various tools and materials visible around it, such as a hammer, nails, screws, and wood planks. The presence of these objects indicates that the house is being built or repaired, and the green roof adds a unique and eco-friendly feature to the structure."
        compare_captions(image, short_caption, long_caption)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment