Skip to content

Instantly share code, notes, and snippets.

@kohya-ss
Created March 17, 2024 23:32
Show Gist options
  • Save kohya-ss/091304f9927ad6207101b02aa248d36f to your computer and use it in GitHub Desktop.
Save kohya-ss/091304f9927ad6207101b02aa248d36f to your computer and use it in GitHub Desktop.
VAEとTAESDのdecode結果を比較するやつ
# Claude 3 Opus にだいたい書いてもらった
# python vae_vs_taesd.py --image_dir /path/to/image/directory
import os
import argparse
import random
from PIL import Image, ImageTk
import torch
from diffusers import AutoencoderKL, AutoencoderTiny
import tkinter as tk
import numpy as np
# Parse the command line: the only (required) option is the directory of test images.
parser = argparse.ArgumentParser(
    description="VAE and TAESD performance comparison",
)
parser.add_argument(
    "--image_dir",
    required=True,
    type=str,
    help="Directory containing images",
)
args = parser.parse_args()

# Run the models on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Load the full SDXL VAE and the tiny TAESD-XL autoencoder from the Hugging Face Hub,
# then move both models to the selected device.
print("loading VAE...")
vae = AutoencoderKL.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    subfolder="vae",
)
print("loading TAESD...")
taesd = AutoencoderTiny.from_pretrained("madebyollin/taesdxl")
for _model in (vae, taesd):
    _model.to(device)
del _model  # keep the module namespace free of the loop variable
# Collect the image files to evaluate.
# Extensions are matched case-insensitively, so "IMG_0001.JPG" is not silently skipped,
# and ".jpeg" is accepted alongside ".jpg". Entries that are not regular files
# (e.g. a directory named "x.png") are ignored. The list is sorted so the presentation
# order does not depend on os.listdir().
_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")
image_files = sorted(
    name
    for name in os.listdir(args.image_dir)
    if name.lower().endswith(_IMAGE_EXTENSIONS)
    and os.path.isfile(os.path.join(args.image_dir, name))
)
if not image_files:
    # Fail early with a clear message. Otherwise the GUI code would raise an IndexError
    # on the first image, and the accuracy computation would divide by zero.
    raise SystemExit(f"No .jpg/.jpeg/.png images found in {args.image_dir}")

# Number of rounds in which the user correctly identified the VAE decode.
correct_count = 0
def load_image(image_path, max_pixels=1024 * 1024, multiple=8):
    """Load an image as RGB, cap its area, and center-crop it to a multiple of ``multiple``.

    Args:
        image_path: Path of the image file to open.
        max_pixels: If width * height exceeds this, the image is resized (keeping the
            aspect ratio) so that its area is at most this value. Defaults to 1024*1024.
        multiple: Width and height are center-cropped down to the nearest multiple of
            this value. Defaults to 8.

    Returns:
        PIL.Image.Image: the processed RGB image.

    Raises:
        ValueError: if the image is too small to yield a non-empty crop.
    """
    # Use a context manager so the underlying file is closed as soon as the pixel data
    # has been copied by convert(), instead of whenever the object is garbage-collected.
    with Image.open(image_path) as src:
        image = src.convert("RGB")

    # Downscale with a uniform ratio so the area does not exceed max_pixels.
    # int() truncation can only make the area smaller, never larger.
    width, height = image.size
    if width * height > max_pixels:
        ratio = (max_pixels / (width * height)) ** 0.5
        image = image.resize((int(width * ratio), int(height * ratio)))

    # Center-crop both sides down to a multiple of `multiple`.
    width, height = image.size
    crop_width = width // multiple * multiple
    crop_height = height // multiple * multiple
    if crop_width == 0 or crop_height == 0:
        raise ValueError(
            f"{image_path}: image size {width}x{height} is too small to crop to a "
            f"multiple of {multiple}"
        )
    left = (width - crop_width) // 2
    top = (height - crop_height) // 2
    return image.crop((left, top, left + crop_width, top + crop_height))
def encode_image(image):
    """Encode an image into VAE latents.

    The image is converted to an array that the code treats as height x width x
    channels (it is permuted to channels-first). Pixel values are mapped from
    [0, 255] to [-1, 1], given a batch axis of size 1, moved to ``device``, and passed
    through ``vae.encode``. One sample is drawn from the resulting latent distribution,
    and the batch axis is removed before returning.
    """
    pixels = np.array(image)
    x = torch.from_numpy(pixels).float() / 127.5 - 1
    x = x.permute(2, 0, 1)[None].to(device)
    with torch.no_grad():
        posterior = vae.encode(x).latent_dist
        z = posterior.sample()
    return z[0]
def decode_latents(latents, model):
    """Decode one latent tensor with ``model`` and return the result as an 8-bit PIL image.

    Args:
        latents: Latents for a single image, without a batch axis. A batch axis of
            size 1 is added before decoding.
        model: An autoencoder whose ``decode(...)`` returns an object with a
            ``.sample`` tensor. The output is mapped from [-1, 1] to [0, 1] below.
            NOTE(review): this range holds for diffusers' AutoencoderKL and
            AutoencoderTiny; confirm it for any other model.

    Returns:
        PIL.Image.Image: the decoded image.
    """
    with torch.no_grad():
        decoded = model.decode(latents.unsqueeze(0)).sample
    # [-1, 1] -> [0, 1]; clamp because decoder outputs can overshoot slightly.
    decoded = (decoded / 2 + 0.5).clamp(0, 1)
    # (1, C, H, W) -> (H, W, C). Cast to float32 first: .numpy() does not support
    # bfloat16, so half/bf16 models would otherwise fail here.
    array = decoded.squeeze(0).permute(1, 2, 0).float().cpu().numpy()
    # Round rather than truncate when quantizing to 8 bits. Plain astype() floors every
    # value, which biases the whole image darker by half a level on average.
    return Image.fromarray((array * 255).round().astype(np.uint8))
def show_results():
    """Display the final accuracy and stop accepting further answers.

    Accuracy is ``correct_count / len(image_files)``: the fraction of images for which
    the user picked the side that showed the VAE decode.
    """
    # Disable both answer buttons. on_select() leaves the last image on screen, so
    # without this, extra clicks after the final round would keep incrementing
    # correct_count and could push the reported accuracy above 1.0.
    left_button.config(state=tk.DISABLED)
    right_button.config(state=tk.DISABLED)
    accuracy = correct_count / len(image_files) if image_files else 0.0
    result_label.config(text=f"Accuracy: {accuracy:.2f}")
def on_select(selected_side):
    """Score the user's answer for the current image, then show the next image or the results.

    Args:
        selected_side: "LEFT" or "RIGHT" (the values passed by the two buttons), i.e. the
            side the user believes shows the VAE decode.
    """
    global correct_count, current_index
    # After the last image has been answered, current_index is pushed past the end of
    # image_files (see below). Later clicks are ignored instead of being scored again,
    # which previously let the accuracy exceed 1.0.
    if current_index >= len(image_files):
        return
    if selected_side == correct_side:
        correct_count += 1
    if current_index < len(image_files) - 1:
        show_next_image()
    else:
        current_index += 1  # mark the run as finished
        show_results()
def show_next_image():
    """Advance to the next image and show its VAE and TAESD decodes side by side.

    The image at ``image_files[current_index + 1]`` is loaded and encoded once with the
    VAE. The latents are then decoded by both the VAE and TAESD. The two results are
    placed on the left/right labels in random order, and ``correct_side`` records
    which side holds the VAE decode.
    """
    global current_index, correct_side
    current_index += 1

    source = load_image(os.path.join(args.image_dir, image_files[current_index]))
    z = encode_image(source)
    by_vae = decode_latents(z, vae)
    # The TAESD decoder receives the latents multiplied by the VAE's scaling factor;
    # the VAE decoder receives them unscaled.
    by_taesd = decode_latents(z * vae.config.scaling_factor, taesd)

    # Randomize placement so the VAE is not always on the same side.
    if random.choice([True, False]):
        correct_side, shown = "LEFT", (by_vae, by_taesd)
    else:
        correct_side, shown = "RIGHT", (by_taesd, by_vae)

    photos = [ImageTk.PhotoImage(picture) for picture in shown]
    targets = (left_label, right_label)
    for target, photo in zip(targets, photos):
        target.config(image=photo)
    # Keep a Python reference to each PhotoImage on its label. This is a common Tkinter
    # idiom that stops the image from being garbage-collected while it is displayed.
    for target, photo in zip(targets, photos):
        target.image = photo
    question_label.config(text="Which side is VAE?")
# Build the Tk window. The two decoded images sit on the outer edges; the question,
# the answer buttons and the result text are packed in between them.
window = tk.Tk()
window.title("VAE vs TAESD")

# Widgets are created in the same order as before, then packed in the order listed
# below. pack() order is what determines the layout.
left_label = tk.Label(window)
right_label = tk.Label(window)
question_label = tk.Label(window, text="")
left_button = tk.Button(window, text="LEFT", command=lambda: on_select("LEFT"))
right_button = tk.Button(window, text="RIGHT", command=lambda: on_select("RIGHT"))
result_label = tk.Label(window, text="")

_layout = (
    (left_label, {"side": tk.LEFT}),
    (right_label, {"side": tk.RIGHT}),
    (question_label, {}),
    (left_button, {"side": tk.LEFT}),
    (right_button, {"side": tk.RIGHT}),
    (result_label, {}),
)
for _widget, _pack_options in _layout:
    _widget.pack(**_pack_options)
del _layout, _widget, _pack_options  # keep the module namespace unchanged

# Initial game state: nothing shown yet and no answer expected.
current_index = -1
correct_side = ""
show_next_image()

# Any key press closes the window.
window.bind("<Key>", lambda _event: window.destroy())
window.mainloop()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment