@kohya-ss
Last active March 18, 2024 23:17
A tool for comparing VAE and TAESD decode results, Gradio version

# Got this working after a lot of back-and-forth with Claude 3 Opus
# python vae_vs_taesd_gradio.py --image_dir /path/to/image/directory
import os
import argparse
import random
from PIL import Image
import torch
from diffusers import AutoencoderKL, AutoencoderTiny
import numpy as np
import gradio as gr

# Set up the command-line argument parser
parser = argparse.ArgumentParser(description="VAE and TAESD performance comparison")
parser.add_argument("--image_dir", type=str, required=True, help="Directory containing images")
args = parser.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the VAE and TAESD from Hugging Face
# vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
# taesd = AutoencoderKL.from_pretrained("Doggettx/sd-xlarge-taesd")
print("loading VAE...")
vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="vae")
print("loading TAESD...")
taesd = AutoencoderTiny.from_pretrained("madebyollin/taesdxl")
vae.to(device)
taesd.to(device)

# Get the list of image files
image_files = [f for f in os.listdir(args.image_dir) if f.endswith(".jpg") or f.endswith(".png")]

def load_image(image_path):
    image = Image.open(image_path).convert("RGB")

    # If the image area is larger than 1024x1024, shrink it while keeping the aspect ratio
    width, height = image.size
    if width * height > 1024 * 1024:
        ratio = (1024 * 1024 / (width * height)) ** 0.5
        new_width = int(width * ratio)
        new_height = int(height * ratio)
        image = image.resize((new_width, new_height))

    # Center-crop so that both width and height are divisible by 8
    width, height = image.size
    crop_width = width // 8 * 8
    crop_height = height // 8 * 8
    left = (width - crop_width) // 2
    top = (height - crop_height) // 2
    right = left + crop_width
    bottom = top + crop_height
    image = image.crop((left, top, right, bottom))

    return image

def encode_image(image):
    # Encode the image with the VAE to obtain latents; pixel values are scaled to [-1, 1]
    # image_tensor = torch.from_numpy(np.array(image)).float() / 255.0
    image_tensor = torch.from_numpy(np.array(image)).float() / 127.5 - 1
    image_tensor = image_tensor.permute(2, 0, 1).unsqueeze(0).to(device)
    with torch.no_grad():
        latents = vae.encode(image_tensor).latent_dist.sample().squeeze(0)
    return latents

def decode_latents(latents, model):
    # Decode the latents back to a PIL image
    with torch.no_grad():
        decoded_image = model.decode(latents.unsqueeze(0)).sample
    decoded_image = (decoded_image / 2 + 0.5).clamp(0, 1)
    decoded_image = decoded_image.squeeze(0).permute(1, 2, 0).cpu().numpy()
    decoded_image = Image.fromarray((decoded_image * 255).astype(np.uint8))
    return decoded_image

def compare_vae_taesd(left_button=None, right_button=None):
    global current_index, correct_count, correct_side

    # Check whether one of the answer buttons was clicked
    button_clicked = left_button is not None or right_button is not None
    if button_clicked:
        if (left_button and correct_side == "LEFT") or (right_button and correct_side == "RIGHT"):
            correct_count += 1
        current_index += 1

    if current_index < len(image_files):
        # Load the next image, encode it with the VAE, then decode with both the VAE and TAESD
        image_path = os.path.join(args.image_dir, image_files[current_index])
        image = load_image(image_path)
        latents = encode_image(image)
        vae_decoded = decode_latents(latents, vae)
        # taesd_decoded = decode_latents(latents, taesd)
        # TAESD expects latents in the diffusion model's scaled space, hence the scaling factor
        taesd_decoded = decode_latents(latents * vae.config.scaling_factor, taesd)

        # Show the decoded images on randomly chosen sides
        if random.choice([True, False]):
            left_image = vae_decoded
            right_image = taesd_decoded
            correct_side = "LEFT"
        else:
            left_image = taesd_decoded
            right_image = vae_decoded
            correct_side = "RIGHT"

        return left_image, right_image, f"Which side is VAE? ({current_index+1}/{len(image_files)})"
    else:
        accuracy = correct_count / len(image_files)
        return None, None, f"Accuracy: {accuracy:.2f}"

# Show the first image when the app loads
def show_first_image():
    global current_index, correct_side
    current_index = 0
    left_image, right_image, question = compare_vae_taesd()
    return left_image, right_image, question

# Set up the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## VAE vs TAESD Comparison")
    with gr.Row():
        left_image = gr.Image()
        right_image = gr.Image()
    with gr.Row():
        left_button = gr.Button("LEFT")
        right_button = gr.Button("RIGHT")
    question_label = gr.Textbox(label="Question")

    # Wrap the handler in lambdas so it knows which button was clicked;
    # passing the Button components themselves as inputs would only hand it their label strings
    left_button.click(lambda: compare_vae_taesd(left_button=True), outputs=[left_image, right_image, question_label])
    right_button.click(lambda: compare_vae_taesd(right_button=True), outputs=[left_image, right_image, question_label])
    demo.load(show_first_image, outputs=[left_image, right_image, question_label])

current_index = 0
correct_count = 0

demo.launch()
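
A note on the latent scaling above: diffusers' AutoencoderKL.decode accepts the raw latents returned by encode(), while TAESD (AutoencoderTiny) decodes latents in the scaled space the diffusion model works in, which is why the script multiplies by vae.config.scaling_factor (0.13025 for the SDXL VAE) before the TAESD decode. A minimal sketch of that round trip in isolation, assuming the same checkpoints as above and a hypothetical sample.png whose sides are multiples of 8:

import numpy as np
import torch
from PIL import Image
from diffusers import AutoencoderKL, AutoencoderTiny

vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="vae")
taesd = AutoencoderTiny.from_pretrained("madebyollin/taesdxl")

image = Image.open("sample.png").convert("RGB")             # hypothetical input file
x = torch.from_numpy(np.array(image)).float() / 127.5 - 1   # scale pixels to [-1, 1]
x = x.permute(2, 0, 1).unsqueeze(0)                         # HWC -> 1xCxHxW

with torch.no_grad():
    raw = vae.encode(x).latent_dist.sample()                          # raw (unscaled) latents
    vae_img = vae.decode(raw).sample                                  # KL VAE decodes raw latents
    taesd_img = taesd.decode(raw * vae.config.scaling_factor).sample  # TAESD decodes scaled latents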