Created
March 17, 2024 23:32
-
-
Save kohya-ss/091304f9927ad6207101b02aa248d36f to your computer and use it in GitHub Desktop.
VAEとTAESDのdecode結果を比較するやつ
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Mostly written by Claude 3 Opus.
# Usage: python vae_vs_taesd.py --image_dir /path/to/image/directory
import os | |
import argparse | |
import random | |
from PIL import Image, ImageTk | |
import torch | |
from diffusers import AutoencoderKL, AutoencoderTiny | |
import tkinter as tk | |
import numpy as np | |
# Set up the command-line argument parser; --image_dir is mandatory.
parser = argparse.ArgumentParser(
    description="VAE and TAESD performance comparison",
)
parser.add_argument(
    "--image_dir",
    type=str,
    required=True,
    help="Directory containing images",
)
args = parser.parse_args()

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the SDXL VAE and the TAESD-XL tiny autoencoder from Hugging Face,
# then move both models to the selected device.
print("loading VAE...")
vae = AutoencoderKL.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    subfolder="vae",
)
print("loading TAESD...")
taesd = AutoencoderTiny.from_pretrained("madebyollin/taesdxl")
vae.to(device)
taesd.to(device)
# Collect the image files to quiz on. Extensions are compared case-insensitively
# so that files such as "photo.JPG" or "scan.PNG" are not silently skipped.
_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")
image_files = [
    f for f in os.listdir(args.image_dir) if f.lower().endswith(_IMAGE_EXTENSIONS)
]
if not image_files:
    # Without this check the first show_next_image() call would index into an
    # empty list and fail with an IndexError.
    raise SystemExit(f"No .jpg/.jpeg/.png images found in {args.image_dir}")

# Number of correct answers given so far.
correct_count = 0
def load_image(image_path):
    """Open an image as RGB, cap its area, and center-crop it to multiples of 8.

    Args:
        image_path: Path of the image file to open.

    Returns:
        A PIL image in RGB mode whose area is at most about 1024 * 1024 pixels
        and whose width and height are both divisible by 8.
    """
    picture = Image.open(image_path).convert("RGB")

    # If the area exceeds 1024x1024, shrink while keeping the aspect ratio.
    max_area = 1024 * 1024
    w, h = picture.size
    area = w * h
    if area > max_area:
        scale = (max_area / area) ** 0.5
        picture = picture.resize((int(w * scale), int(h * scale)))

    # Center-crop so that both width and height are divisible by 8.
    w, h = picture.size
    crop_w = w - w % 8
    crop_h = h - h % 8
    x0 = (w - crop_w) // 2
    y0 = (h - crop_h) // 2
    return picture.crop((x0, y0, x0 + crop_w, y0 + crop_h))
def encode_image(image):
    """Encode an image with the module-level VAE and return sampled latents.

    Args:
        image: Image convertible with ``np.array``; the caller in this file
            passes the RGB PIL image returned by ``load_image``.

    Returns:
        A latent tensor on ``device`` with the leading batch axis removed.
    """
    # Rescale pixel values so that 0 maps to -1 and 255 maps to +1.
    pixels = torch.from_numpy(np.array(image)).float() / 127.5 - 1
    # Move the last axis to the front, add a batch axis, and move to the device.
    batch = pixels.permute(2, 0, 1).unsqueeze(0).to(device)
    with torch.no_grad():
        posterior = vae.encode(batch).latent_dist
        # The latents are sampled from the posterior, so repeated calls on the
        # same image may differ.
        return posterior.sample().squeeze(0)
def decode_latents(latents, model):
    """Decode unbatched latents with ``model`` and return a PIL image.

    Args:
        latents: Latent tensor without a batch axis (one is added here).
        model: Object exposing ``decode(x).sample``; this file passes either
            the VAE or the TAESD model.

    Returns:
        The decoded picture as a PIL image built from a uint8 array.
    """
    with torch.no_grad():
        decoded = model.decode(latents.unsqueeze(0)).sample

    # Map the decoder output from [-1, 1] to [0, 1] and clip out-of-range values.
    decoded = (decoded / 2 + 0.5).clamp(0, 1)
    # Drop the batch axis, put the first remaining axis last, and copy to host memory.
    array = decoded.squeeze(0).permute(1, 2, 0).cpu().numpy()
    # NOTE: astype() truncates towards zero rather than rounding.
    return Image.fromarray((array * 255).astype(np.uint8))
def show_results():
    """Display the final accuracy and stop accepting further answers.

    Accuracy is the number of correct guesses divided by the number of images.
    The LEFT/RIGHT buttons are disabled afterwards: on_select() would otherwise
    keep counting clicks made after the last image, letting the reported
    accuracy grow past 1.00.
    """
    total = len(image_files)
    # Guard against an empty image list so this can never divide by zero.
    accuracy = correct_count / total if total else 0.0
    result_label.config(text=f"Accuracy: {accuracy:.2f}")

    # The quiz is over; make any further clicks a no-op.
    left_button.config(state=tk.DISABLED)
    right_button.config(state=tk.DISABLED)
def on_select(selected_side):
    """Handle a LEFT/RIGHT answer: score it, then advance or show the results.

    Args:
        selected_side: ``"LEFT"`` or ``"RIGHT"``, the side the user picked.
    """
    global correct_count

    # A guess is correct only when it names the side recorded in correct_side.
    guessed_right = selected_side in ("LEFT", "RIGHT") and selected_side == correct_side
    if guessed_right:
        correct_count += 1

    # After the last image show the score; otherwise move on to the next one.
    if current_index >= len(image_files) - 1:
        show_results()
    else:
        show_next_image()
def show_next_image():
    """Advance to the next image and show its two reconstructions side by side.

    The image is encoded once with the VAE and the same latents are decoded by
    both the VAE and TAESD. The two results are placed left/right at random and
    the side holding the VAE output is recorded in ``correct_side``.
    """
    global current_index, correct_side

    current_index += 1

    # Load the image, encode it with the VAE, then decode with both models.
    source = load_image(os.path.join(args.image_dir, image_files[current_index]))
    latents = encode_image(source)
    vae_decoded = decode_latents(latents, vae)
    # TAESD is given the latents multiplied by the VAE's configured scaling factor.
    taesd_decoded = decode_latents(latents * vae.config.scaling_factor, taesd)

    # Randomly decide which side gets the VAE reconstruction.
    vae_on_left = random.choice([True, False])
    if vae_on_left:
        correct_side = "LEFT"
        panes = (vae_decoded, taesd_decoded)
    else:
        correct_side = "RIGHT"
        panes = (taesd_decoded, vae_decoded)

    # Show the pictures.
    for label, picture in zip((left_label, right_label), panes):
        photo = ImageTk.PhotoImage(picture)
        label.config(image=photo)
        # Store the photo on the label so a reference to it outlives this call.
        label.image = photo

    question_label.config(text="Which side is VAE?")
# Build the Tkinter GUI.
window = tk.Tk()
window.title("VAE vs TAESD")

# Two image panes, packed against the left and right edges of the window.
left_label = tk.Label(master=window)
left_label.pack(side=tk.LEFT)
right_label = tk.Label(master=window)
right_label.pack(side=tk.RIGHT)

# Prompt text; filled in by show_next_image().
question_label = tk.Label(master=window, text="")
question_label.pack()

# Answer buttons; each reports the side it stands for to on_select().
left_button = tk.Button(master=window, text="LEFT", command=lambda: on_select("LEFT"))
left_button.pack(side=tk.LEFT)
right_button = tk.Button(master=window, text="RIGHT", command=lambda: on_select("RIGHT"))
right_button.pack(side=tk.RIGHT)

# Final score; filled in by show_results().
result_label = tk.Label(master=window, text="")
result_label.pack()

# Quiz state: start at -1 because show_next_image() increments before loading.
current_index, correct_side = -1, ""
show_next_image()

# Pressing any key closes the window.
window.bind("<Key>", lambda _event: window.destroy())
window.mainloop()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment