Skip to content

Instantly share code, notes, and snippets.

@rockerBOO
Created November 7, 2023 20:49
Show Gist options
  • Save rockerBOO/f1b3c18f4b9fc161310b47d9cb39fcba to your computer and use it in GitHub Desktop.
Save rockerBOO/f1b3c18f4b9fc161310b47d9cb39fcba to your computer and use it in GitHub Desktop.
# Original from https://gist.github.com/Poiuytrezay1/db6b98672675456bed39d45077d44179
# Credit to Poiuytrezay1
import argparse
import os
from collections import defaultdict
from pathlib import Path
import numpy as np
import torch
import tqdm
from PIL import Image
from torchvision import transforms
import library.model_util as model_util
import library.sdxl_train_util as sdxl_train_util
# Preprocessing applied to every loaded image before it is handed to the VAE:
# convert the array to a tensor, then normalize with mean 0.5 and std 0.5
# (i.e. ``(x - 0.5) / 0.5``). The decode step later undoes this scaling.
IMAGE_TRANSFORMS = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.5])]
)
def load_image(image_path):
    """Load an image file as an RGB ``uint8`` array.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        A ``(pixels, info)`` tuple: ``pixels`` is the image converted to RGB
        as a ``numpy`` ``uint8`` array, ``info`` is the ``info`` dictionary
        of the (possibly converted) PIL image.
    """
    # Image.open is lazy and keeps the file open; the context manager
    # guarantees the handle is released once the pixels have been read.
    with Image.open(image_path) as source_image:
        if source_image.mode == "RGB":
            rgb_image = source_image
        else:
            rgb_image = source_image.convert("RGB")
        # Materialize the pixel data while the file is still open.
        pixels = np.array(rgb_image, np.uint8)
        info = rgb_image.info
    return pixels, info
@torch.no_grad()
def process_images_group(vae, images_group):
    """Encode a group of same-sized image tensors into VAE latents.

    Args:
        vae: Model exposing ``device`` and ``encode``; the object returned by
            ``encode`` must provide ``latent_dist.sample()``.
        images_group: Sequence of image tensors that all share one shape, so
            they can be stacked into a single batch.

    Returns:
        The latents sampled from the encoder's latent distribution. Only the
        encoding half happens here; decoding is done by the callers.
    """
    # Build one batch along a new leading dimension and move it to the VAE.
    batch = torch.stack(images_group, dim=0)
    batch = batch.to(vae.device)
    posterior = vae.encode(batch).latent_dist
    return posterior.sample()
def process_latents_from_images(vae, input_file_or_dir, output_dir, args):
    """Encode images with the VAE, decode the latents and save the results.

    Args:
        vae: VAE model; must expose ``device`` and ``encode`` (and ``decode``
            unless the consistency decoder is used).
        input_file_or_dir: A single image file, or a directory whose image
            files (non-recursive, matched by suffix, case-insensitive) are
            processed in sorted order.
        output_dir: Directory that receives the decoded images; created when
            it does not exist.
        args: Parsed command line namespace. Reads ``consistency_decoder``,
            ``vae`` and ``gif``.
    """
    decoder_consistency = None
    if args.consistency_decoder:
        # Imported lazily so the optional dependency is only required when
        # the consistency decoder is actually requested.
        from consistencydecoder import ConsistencyDecoder

        decoder_consistency = ConsistencyDecoder(device=vae.device)

    input_path = Path(input_file_or_dir)
    output = Path(output_dir)
    os.makedirs(str(output.absolute()), exist_ok=True)

    image_suffixes = {".jpg", ".jpeg", ".png", ".webp", ".bmp", ".avif"}
    if input_path.is_dir():
        # Compare suffixes case-insensitively so e.g. "photo.JPG" is picked up.
        image_files = sorted(
            file
            for file in input_path.iterdir()
            if file.suffix.lower() in image_suffixes
        )
    else:
        image_files = [input_path]

    # Group images by spatial size so they can be stacked into batches. Each
    # tensor is stored together with its file so the two can never drift
    # apart when several images share one size.
    size_to_items = defaultdict(list)
    for image_file in image_files:
        image, _ = load_image(image_file)
        transformed_image = IMAGE_TRANSFORMS(image)
        size_to_items[transformed_image.shape[1:]].append(
            (image_file, transformed_image)
        )

    total_images = len(image_files)

    # NOTE: args.batch_size is intentionally ignored for now.
    batch_size = 1
    print("Temporarily limiting batch size to 1")

    # An unset --vae may arrive as None or as an empty string.
    vae_name = Path(args.vae).stem if args.vae else None

    with tqdm.tqdm(total=total_images) as progress_bar:
        for items in size_to_items.values():
            for start in range(0, len(items), batch_size):
                batch = items[start : start + batch_size]
                batch_file_names = [image_file for image_file, _ in batch]
                batch_tensors = [tensor for _, tensor in batch]

                latents = process_images_group(vae, batch_tensors)

                if args.consistency_decoder:
                    consistencydecoder_and_save(
                        decoder_consistency,
                        latents,
                        batch_file_names,
                        output,
                        device=vae.device,
                    )
                else:
                    decode_vae_and_save(
                        vae,
                        latents,
                        batch_file_names,
                        output,
                        gif=args.gif,
                        vae_name=vae_name,
                    )

                # The bar counts images, so advance by the images just done.
                progress_bar.update(len(batch))
def decode_vae_and_save(vae, latents, filenames, output, gif=False, vae_name=None):
    """Decode latents with the VAE and save each result as a PNG.

    One latent is decoded per entry of ``filenames`` (latents and file names
    are paired in order; surplus entries on either side are ignored), one at
    a time.

    Args:
        vae: Model whose ``decode(...)`` result exposes ``.sample``.
        latents: Batch of latents, indexed along the first dimension.
        filenames: Paths of the original images, aligned with ``latents``.
        output: ``pathlib.Path`` of the output directory.
        gif: When true, additionally write a two-frame GIF (original image,
            then decoded image) next to the PNG.
        vae_name: Optional name appended to the output file names; ``None``
            or an empty string adds nothing.
    """
    vae_file_part = f"-{vae_name}" if vae_name else ""
    output_dir = output.absolute()

    with torch.no_grad():
        for original_file, latent in zip(filenames, latents):
            original_file = Path(original_file)

            # Decode a single latent, restoring the batch dimension.
            decoded = vae.decode(latent.unsqueeze(0)).sample

            # Rescale from [-1, 1] to [0, 255] and move channels last.
            decoded_image = (
                ((decoded / 2 + 0.5).clamp(0, 1) * 255)
                .cpu()
                .permute(0, 2, 3, 1)
                .numpy()
                .astype("uint8")
            )[0]

            output_file = (
                output_dir / f"{original_file.stem}-latents-decoded{vae_file_part}.png"
            )
            output_image = Image.fromarray(decoded_image)
            print(f"Saving to {output_file}")
            output_image.save(output_file)

            if gif:
                output_gif_file = (
                    output_dir
                    / f"{original_file.stem}-latents-decoded{vae_file_part}.gif"
                )
                # Close the original image file once the GIF is written.
                with Image.open(original_file) as original_image:
                    original_image.save(
                        output_gif_file,
                        save_all=True,
                        append_images=[output_image],
                        duration=500,
                        loop=0,
                    )
def consistencydecoder_and_save(
    decoder_consistency, latents, filenames, output_dir, device
):
    """Decode latents with the consistency decoder and save them as PNGs.

    Args:
        decoder_consistency: Callable that maps the latents batch to an
            iterable of decoded samples.
        latents: Batch of latents handed to ``decoder_consistency``.
        filenames: Paths of the original images, aligned with the samples;
            only their stem is used to build the output names.
        output_dir: Directory the PNG files are written to.
        device: Unused; kept so existing callers keep working.
    """
    # Optional dependency, imported only when this decoder is used.
    from consistencydecoder import save_image

    with torch.no_grad():
        sample_consistences = decoder_consistency(latents)

    # Samples and file names are paired in order; surplus entries on either
    # side are ignored rather than raising.
    for original_file_name, decoded_image in zip(filenames, sample_consistences):
        # Use only the stem: keeping the source directory would send the
        # result next to the input (or into a missing sub-directory) instead
        # of into output_dir.
        original_stem = Path(original_file_name).stem
        save_image(
            decoded_image,
            os.path.join(
                output_dir,
                f"{original_stem}-latents-decoded-consistency.png",
            ),
        )
def main(args):
    """Load a VAE as described by ``args`` and run the encode/decode round trip.

    The VAE comes from ``args.vae`` when that is set (not ``None`` and not
    empty); otherwise it is taken from the Stable Diffusion checkpoint named
    by ``args.pretrained_model_name_or_path``.

    Args:
        args: Parsed command line namespace. Reads ``device``, ``vae``,
            ``pretrained_model_name_or_path``, ``sdxl``, ``v2``,
            ``input_file_or_dir`` and ``output_dir``; the whole namespace is
            also forwarded to the SDXL loader and to the processing step.

    Raises:
        ValueError: If neither ``vae`` nor ``pretrained_model_name_or_path``
            is provided.
    """
    if not args.vae and not args.pretrained_model_name_or_path:
        raise ValueError(
            "Either --vae or --pretrained_model_name_or_path must be provided"
        )

    device = torch.device(args.device)

    # The CLI may deliver an unset --vae as "" rather than None, so test for
    # emptiness instead of identity with None.
    if not args.vae:
        if args.sdxl:
            # putting this in here just to be able to pass the argument
            from accelerate import Accelerator

            accelerator = Accelerator()
            _, _, _, vae, _, _, _ = sdxl_train_util.load_target_model(
                args,
                accelerator,
                args.pretrained_model_name_or_path,
                torch.float16,
            )
        else:
            # Load model's VAE
            _, vae, _ = model_util.load_models_from_stable_diffusion_checkpoint(
                args.v2,
                args.pretrained_model_name_or_path,
            )
        # Applied after both branches: move the VAE to the target device and
        # cast it to float32 (the SDXL branch requests float16 from its loader).
        vae.to(device, dtype=torch.float32)
    else:
        vae = model_util.load_vae(args.vae, torch.float32).to(device)

    # Save image decoded latents
    process_latents_from_images(vae, args.input_file_or_dir, args.output_dir, args)
# Command line entry point: parse the options and hand them to main().
if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--device", default="cpu")
    # Both paths are mandatory; without them the script cannot do anything.
    argparser.add_argument(
        "--input_file_or_dir",
        required=True,
        help="Input file or directory to load the images from",
    )
    argparser.add_argument(
        "--output_dir",
        required=True,
        help="Output directory to put the VAE decoded images",
    )
    # Default is None (not "") so "no VAE given" is detectable with
    # ``is None`` as well as with a truthiness check.
    argparser.add_argument(
        "--vae", default=None, help="Path to VAE file or hugging face VAE path"
    )
    argparser.add_argument(
        "--pretrained_model_name_or_path",
        default="",
        help="Stable diffusion model name or path to load the VAE from.",
    )
    argparser.add_argument(
        "--gif",
        action="store_true",
        help="Make a gif of the decoded image with the original",
    )
    argparser.add_argument(
        "--v2", action="store_true", help="Is a Stable Diffusion v2 model."
    )
    argparser.add_argument(
        "--batch_size", type=int, default=1, help="Batch size to process the images."
    )
    argparser.add_argument(
        "--sdxl", action="store_true", help="(NOTWORKING) SDXL model"
    )
    # The next three options are not read directly by this script; the whole
    # namespace is passed to the SDXL loader.
    argparser.add_argument("--lowram", type=int, default=1, help="SDXL low ram option")
    argparser.add_argument(
        "--full_fp16", type=int, default=1, help="SDXL use full fp16"
    )
    argparser.add_argument(
        "--full_bf16", type=int, default=1, help="SDXL use full bf16"
    )
    argparser.add_argument(
        "--consistency_decoder",
        action="store_true",
        help="Use Consistency Decoder from OpenAI https://github.com/openai/consistencydecoder",
    )
    args = argparser.parse_args()
    main(args)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment