Skip to content

Instantly share code, notes, and snippets.

@laksjdjf
Created May 25, 2024 01:06
Show Gist options
  • Save laksjdjf/dc9eb509a4151777c1d627f6c3329d92 to your computer and use it in GitHub Desktop.
Save laksjdjf/dc9eb509a4151777c1d627f6c3329d92 to your computer and use it in GitHub Desktop.
from PIL import Image
import hpsv2
import torch
class HPSv2:
    """ComfyUI node that ranks an image batch by HPS score against a prompt.

    Every image is scored with ``hpsv2.score`` (checkpoint ``HPS_VERSION``);
    the batch is returned re-ordered from highest to lowest score, together
    with a text summary of the sorted scores.
    """

    # Checkpoint version forwarded to ``hpsv2.score``.
    HPS_VERSION = "v2.1"

    RETURN_TYPES = ("IMAGE", "STRING")
    FUNCTION = "apply"
    CATEGORY = "_for_testing"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node inputs: an IMAGE batch and a multiline prompt."""
        return {
            "required": {
                "images": ("IMAGE", ),
                "prompt": ("STRING", {"multiline": True}),
            },
        }

    @staticmethod
    def _to_uint8(images):
        """Convert a float image tensor to a uint8 NumPy array.

        Values are clamped to [0, 1] before scaling so out-of-range floats
        saturate at 0/255 instead of hitting an undefined/wrapping uint8 cast,
        and are rounded rather than truncated.
        """
        scaled = (images.detach().cpu().clamp(0.0, 1.0) * 255.0).round()
        return scaled.numpy().astype("uint8")

    @staticmethod
    def _sort_by_score(images, scores):
        """Return ``(images, scores)`` re-ordered by descending score.

        The sort is stable, so images with equal scores keep their original
        relative order. ``images`` is indexed along its first (batch) axis.
        """
        order = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
        return images[order], [scores[i] for i in order]

    @staticmethod
    def _format_scores(scores):
        """Render scores as one ``image_<rank>:<score>`` line each (4 decimals)."""
        return "\n".join(
            f"image_{rank}:{score:.4f}"
            for rank, score in enumerate(scores, start=1)
        )

    def apply(self, images, prompt):
        """Score each image against ``prompt`` and sort the batch best-first.

        Args:
            images: ComfyUI IMAGE batch; batch is axis 0 (presumably
                ``(B, H, W, C)`` floats in [0, 1] — ComfyUI convention).
            prompt: Text the images are scored against.

        Returns:
            ``(sorted_images, result_string)``: the batch ordered by descending
            HPS score, and one ``image_<rank>:<score>`` line per image, where
            rank 1 is the first image of ``sorted_images``. An empty batch is
            returned unchanged with an empty string.
        """
        if images.shape[0] == 0:
            return (images, "")

        pil_images = [Image.fromarray(arr) for arr in self._to_uint8(images)]
        # hpsv2.score returns a one-element list for a single PIL image.
        scores = [
            float(hpsv2.score(img, prompt, hps_version=self.HPS_VERSION)[0])
            for img in pil_images
        ]

        sorted_images, sorted_scores = self._sort_by_score(images, scores)
        return (sorted_images, self._format_scores(sorted_scores))
# Node registry in the ComfyUI custom-node convention: maps the node's type
# name to the class that implements it.
NODE_CLASS_MAPPINGS = dict(HPSv2=HPSv2)

# Restrict `from <module> import *` to the registry only.
__all__ = ["NODE_CLASS_MAPPINGS"]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment