Skip to content

Instantly share code, notes, and snippets.

@rockerBOO
Last active November 7, 2023 20:24
Show Gist options
  • Save rockerBOO/92ca379be8d2fb0da7c9c4cba35c9f43 to your computer and use it in GitHub Desktop.
Save rockerBOO/92ca379be8d2fb0da7c9c4cba35c9f43 to your computer and use it in GitHub Desktop.
# Original from https://gist.github.com/Poiuytrezay1/db6b98672675456bed39d45077d44179
# Credit to Poiuytrezay1
import argparse
import os
from collections import defaultdict
from pathlib import Path
import numpy as np
import torch
import tqdm
from PIL import Image
from torchvision import transforms
import library.model_util as model_util
import library.sdxl_train_util as sdxl_train_util
# Preprocessing applied to every loaded image before VAE encoding:
# ToTensor turns the (H, W, C) uint8 array from load_image into a float CHW
# tensor scaled to [0, 1]; Normalize(mean=0.5, std=0.5) then maps it to [-1, 1].
IMAGE_TRANSFORMS = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.5])]
)
def load_image(image_path):
    """Load an image from disk as an RGB ``uint8`` NumPy array.

    Args:
        image_path: Path (``str`` or ``pathlib.Path``) of the image file.

    Returns:
        A tuple ``(img, info)``: ``img`` is ``np.array(image, np.uint8)`` of the
        image converted to RGB (shape (H, W, 3)), and ``info`` is the PIL
        ``Image.info`` metadata dict of that (possibly converted) image.
    """
    # Open the image as a context manager so the underlying file handle is
    # closed once the pixels have been copied into the NumPy array.
    with Image.open(image_path) as image:
        if image.mode != "RGB":
            image = image.convert("RGB")
        img = np.array(image, np.uint8)
        info = image.info
    return img, info
def process_images_group(vae, images_group):
    """Encode a batch of equally-sized image tensors into VAE latents.

    Args:
        vae: VAE exposing ``device``, ``encode(x).latent_dist.sample()`` and,
            optionally, ``dtype``.
        images_group: Sequence of image tensors that all share one shape
            (required by ``torch.stack``).

    Returns:
        Latents sampled from the encoder's latent distribution, one per input
        image along dim 0. Nothing is decoded here; the caller decodes them.
    """
    with torch.no_grad():
        # Stack the same-sized tensors into one batch along a new leading dim.
        img_tensors = torch.stack(images_group, dim=0)
        # Match the VAE's device and, when it exposes one, its dtype: the
        # preprocessed images are float32, which would otherwise mismatch a
        # half-precision VAE's weights.
        img_tensors = img_tensors.to(
            device=vae.device, dtype=getattr(vae, "dtype", img_tensors.dtype)
        )
        return vae.encode(img_tensors).latent_dist.sample()
def _collect_image_files(input_path):
    """Return the image files to process: the path itself, or its image children."""
    if not input_path.is_dir():
        return [input_path]
    # Compare suffixes case-insensitively so files like "photo.JPG" are included.
    return [
        file
        for file in input_path.iterdir()
        if file.suffix.lower() in (".jpg", ".jpeg", ".png", ".webp", ".bmp", ".avif")
    ]


def _group_images_by_size(image_files):
    """Load and transform each file, grouping ``(file, tensor)`` pairs by tensor size."""
    size_to_items = defaultdict(list)
    for image_file in image_files:
        image, _ = load_image(image_file)
        transformed_image = IMAGE_TRANSFORMS(image)
        # Only same-sized tensors can be stacked into one batch. Keep each file
        # paired with its tensor so names stay correct after regrouping.
        size_to_items[transformed_image.shape[1:]].append((image_file, transformed_image))
    return size_to_items


def process_latents_from_images(vae, input_file_or_dir, output_dir, args):
    """Encode each input image with ``vae`` and save its decoded reconstruction.

    Args:
        vae: VAE used for encoding (and for decoding unless the consistency
            decoder is selected).
        input_file_or_dir: A single image file, or a directory whose image
            files (by extension) are processed.
        output_dir: Directory to write results into; created if missing.
        args: Parsed CLI namespace; reads ``consistency_decoder``,
            ``batch_size`` and ``gif``.
    """
    if args.consistency_decoder:
        from consistencydecoder import ConsistencyDecoder

        decoder_consistency = ConsistencyDecoder(device=vae.device)

    output = Path(output_dir)
    output.mkdir(parents=True, exist_ok=True)

    size_to_items = _group_images_by_size(_collect_image_files(Path(input_file_or_dir)))
    total_images = sum(len(items) for items in size_to_items.values())
    batch_size = args.batch_size

    with tqdm.tqdm(total=total_images) as progress_bar:
        for items in size_to_items.values():
            # Batch within one size group; file names come from the same pairs
            # as the tensors, never from a global index.
            for start in range(0, len(items), batch_size):
                batch_items = items[start : start + batch_size]
                batch_file_names = [file for file, _ in batch_items]
                batch = [tensor for _, tensor in batch_items]
                latents = process_images_group(vae, batch)
                if args.consistency_decoder:
                    consistencydecoder_and_save(
                        decoder_consistency,
                        latents,
                        batch_file_names,
                        output,
                        device=vae.device,
                    )
                else:
                    decode_vae_and_save(
                        vae, latents, batch_file_names, output, gif=args.gif
                    )
                # The bar's total counts images, so advance by the batch size.
                progress_bar.update(len(batch))
def decode_vae_and_save(vae, latents, filenames, output, gif=False):
    """Decode a batch of latents with the VAE and save each result as a PNG.

    Args:
        vae: VAE whose ``decode(z).sample`` yields images in [-1, 1].
        latents: Batch of latents; entry ``j`` belongs to ``filenames[j]``.
        filenames: ``pathlib.Path`` of the source image for each latent. The
            stem names the outputs, and the file itself is read when ``gif``.
        output: ``pathlib.Path`` of the directory to write into.
        gif: When true, also write a two-frame GIF (original, then decoded).
    """
    if latents.shape[0] == 0:
        return

    with torch.no_grad():
        # Decode every latent in the batch, one at a time (each as a batch of
        # one), presumably to limit decoder memory use.
        decoded = torch.cat([vae.decode(latent).sample for latent in latents.split(1)])

    # Rescale from [-1, 1] to [0, 255] and move channels last (N, H, W, C) for PIL.
    decoded_images = (
        ((decoded / 2 + 0.5).clamp(0, 1) * 255)
        .cpu()
        .permute(0, 2, 3, 1)
        .numpy()
        .astype("uint8")
    )

    output_dir = output.absolute()
    for j, decoded_image in enumerate(decoded_images):
        original_file = filenames[j]
        output_image = Image.fromarray(decoded_image)
        output_image.save(output_dir / f"{original_file.stem}-latents-decoded.png")
        if gif:
            # Two-frame animation: original image followed by its reconstruction.
            # The context manager closes the original file afterwards.
            with Image.open(original_file) as original_image:
                original_image.save(
                    output_dir / f"{original_file.stem}-latents-decoded.gif",
                    save_all=True,
                    append_images=[output_image],
                    duration=500,
                    loop=0,
                )
def consistencydecoder_and_save(
    decoder_consistency, latents, filenames, output_dir, device
):
    """Decode latents with OpenAI's Consistency Decoder and save each as a PNG.

    Args:
        decoder_consistency: Decoder callable mapping a latent batch to images.
        latents: Batch of latents; entry ``j`` belongs to ``filenames[j]``.
        filenames: Path of the source image for each latent. Only its stem is
            used to name the output file.
        output_dir: Directory to write into.
        device: Unused; kept for call-site compatibility.

    Outputs are named ``<stem>-latents-decoded-consistency.png`` inside
    ``output_dir``, regardless of where the source image lives.
    """
    from consistencydecoder import save_image

    with torch.no_grad():
        sample_consistences = decoder_consistency(latents)

    output_path = Path(output_dir)
    for j in range(len(sample_consistences)):
        # Use only the stem: joining a full (possibly absolute) source path onto
        # output_dir would escape or nest outside the output directory.
        stem = Path(filenames[j]).stem
        # save_image is passed a batch (as in the consistencydecoder README), so
        # hand it a one-image slice rather than an unbatched tensor.
        # NOTE(review): confirm against the installed consistencydecoder version.
        save_image(
            sample_consistences[j : j + 1],
            str(output_path / f"{stem}-latents-decoded-consistency.png"),
        )
def main(args):
    """Load a VAE according to ``args`` and round-trip the input images through it.

    The VAE comes from ``args.vae`` when that is non-empty. Otherwise it is
    extracted from ``args.pretrained_model_name_or_path``, using the SDXL
    loader when ``args.sdxl`` is set.

    Raises:
        ValueError: If neither a VAE path nor a model path is provided.
    """
    device = torch.device(args.device)
    # args.vae may be None or an empty string (e.g. a CLI default); treat both
    # as "no separate VAE". Normalise to None on a copy of args, because the
    # SDXL loader receives the whole namespace (presumably it reads args.vae
    # itself — verify).
    args = argparse.Namespace(**{**vars(args), "vae": args.vae or None})

    if args.vae is None:
        if not args.pretrained_model_name_or_path:
            raise ValueError(
                "Either --vae or --pretrained_model_name_or_path must be given"
            )
        if args.sdxl:
            # The SDXL loader takes an Accelerator argument.
            from accelerate import Accelerator

            accelerator = Accelerator()
            _, _, _, vae, _, _, _ = sdxl_train_util.load_target_model(
                args,
                accelerator,
                args.pretrained_model_name_or_path,
                torch.float16,
            )
        else:
            # Load the VAE out of the Stable Diffusion checkpoint.
            _, vae, _ = model_util.load_models_from_stable_diffusion_checkpoint(
                args.v2,
                args.pretrained_model_name_or_path,
            )
        # The SDXL branch requests float16. Images are fed as float32, so move
        # whichever VAE was loaded to the target device in float32.
        vae.to(device, dtype=torch.float32)
    else:
        vae = model_util.load_vae(args.vae, torch.float32).to(device)

    # Save image decoded latents
    process_latents_from_images(vae, args.input_file_or_dir, args.output_dir, args)
if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--device", default="cpu")
    # Both paths are needed: Path(None) would otherwise fail deep in processing.
    argparser.add_argument(
        "--input_file_or_dir",
        required=True,
        help="Input file or directory to load the images from",
    )
    argparser.add_argument(
        "--output_dir",
        required=True,
        help="Output directory to put the VAE decoded images",
    )
    # Default None (not ""): main selects the checkpoint-VAE path when no
    # separate VAE is given.
    argparser.add_argument(
        "--vae", default=None, help="Path to VAE file or hugging face VAE path"
    )
    argparser.add_argument(
        "--pretrained_model_name_or_path",
        default="",
        help="Stable diffusion model name or path to load the VAE from.",
    )
    argparser.add_argument(
        "--gif",
        action="store_true",
        help="Make a gif of the decoded image with the original",
    )
    argparser.add_argument(
        "--v2", action="store_true", help="Is a Stable Diffusion v2 model."
    )
    argparser.add_argument(
        "--batch_size", type=int, default=1, help="Batch size to process the images."
    )
    argparser.add_argument(
        "--sdxl", action="store_true", help="(NOTWORKING) SDXL model"
    )
    # The next three are read by the SDXL loader (it receives the args namespace).
    argparser.add_argument("--lowram", type=int, default=1, help="SDXL low ram option")
    argparser.add_argument(
        "--full_fp16", type=int, default=1, help="SDXL use full fp16"
    )
    argparser.add_argument(
        "--full_bf16", type=int, default=1, help="SDXL use full bf16"
    )
    argparser.add_argument(
        "--consistency_decoder",
        action="store_true",
        help="Use Consistency Decoder from OpenAI https://github.com/openai/consistencydecoder",
    )
    args = argparser.parse_args()
    main(args)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment