Skip to content

Instantly share code, notes, and snippets.

Created December 17, 2023 13:02
Show Gist options
  • Save btlorch/5ddea5e6a0951995d09536fbc95a3dfd to your computer and use it in GitHub Desktop.
Save btlorch/5ddea5e6a0951995d09536fbc95a3dfd to your computer and use it in GitHub Desktop.
Compress images with the SDXL auto-encoder
import argparse
from PIL import Image
import numpy as np
from glob import glob
import os
from tqdm import tqdm
from diffusers import DiffusionPipeline
import torch
import torchvision.transforms as T
# Device used for both the model and the input tensors: the first CUDA GPU
# when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")
def prepare_image(filepath):
    """Load an image file as a batched tensor in the range [-1, 1].

    Args:
        filepath: Path to an image file readable by PIL.

    Returns:
        A float tensor of shape (1, 3, H, W) on ``DEVICE`` with values in
        [-1, 1], the range this script feeds to the auto-encoder.
    """
    # Read the PIL image and force three RGB channels. convert() returns a
    # fully loaded copy, so the underlying file handle can be closed at once.
    with Image.open(filepath) as pil_img:
        img = pil_img.convert("RGB")
    # Convert to tensor in range [0, 1] with shape (C, H, W)
    img_torch = T.ToTensor()(img).to(DEVICE)
    # Add singleton batch dimension
    img_batch = torch.unsqueeze(img_torch, dim=0)
    # Normalize to range [-1, 1]
    img_batch = img_batch * 2. - 1.
    return img_batch
def tensor_to_pil(x):
    """Convert a single image tensor in [-1, 1] to an RGB PIL image.

    Args:
        x: Tensor of shape (C, H, W) without a batch dimension, nominally in
            [-1, 1]. Values outside that range are clipped.

    Returns:
        A ``PIL.Image.Image`` in RGB mode.
    """
    x = x.detach().cpu()
    # Clip to range [-1, 1] to guard against values outside the nominal range
    x = torch.clamp(x, -1., 1.)
    # Scale to range [0, 1]
    x = (x + 1.) / 2.
    # Move channel axis to the end: (C, H, W) -> (H, W, C)
    x = x.permute(1, 2, 0).numpy()
    # PIL cannot build an image from an (H, W, 1) array; drop the channel axis
    # so single-channel input becomes a grayscale image (converted to RGB below).
    if x.shape[2] == 1:
        x = x[:, :, 0]
    # Scale to uint8 range. Round to the nearest level instead of truncating,
    # which would bias every pixel downwards by half a level on average.
    x = np.round(255 * x).astype(np.uint8)
    # Convert to PIL image
    x = Image.fromarray(x)
    if not x.mode == "RGB":
        x = x.convert("RGB")
    return x
def load_vae():
    """Load the SDXL base auto-encoder (VAE) and move it to ``DEVICE``.

    Returns:
        The auto-encoder from the ``stabilityai/stable-diffusion-xl-base-1.0``
        repository, in float32, in eval mode and placed on ``DEVICE``.
    """
    # Local import: only the auto-encoder is needed, so load that sub-model
    # directly instead of instantiating the full pipeline with all its other
    # components.
    from diffusers import AutoencoderKL

    vae = AutoencoderKL.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        subfolder="vae",
        torch_dtype=torch.float32,
        # Same checkpoint files as before: the fp16 variant, upcast to float32.
        variant="fp16",
        use_safetensors=True,
    )
    # prepare_image() places inputs on DEVICE, so the model must live there
    # too; otherwise the forward pass fails with a device mismatch on a GPU.
    vae = vae.to(DEVICE)
    vae.eval()
    return vae
def compress_decompress(filepaths, output_dir, vae):
    """Round-trip each image through the auto-encoder and save the result as PNG.

    Args:
        filepaths: Iterable of input image paths.
        output_dir: Directory for the reconstructions; created if missing.
            Each output keeps the input's base name with a ``.png`` extension.
        vae: Auto-encoder whose call returns a mapping with a ``"sample"``
            entry, e.g. the model returned by ``load_vae()``.

    Inputs whose output file already exists are skipped, and the existing
    file is left untouched.
    """
    os.makedirs(output_dir, exist_ok=True)
    for filepath in tqdm(filepaths):
        output_filepath = os.path.join(output_dir, os.path.splitext(os.path.basename(filepath))[0] + ".png")
        if os.path.exists(output_filepath):
            # tqdm.write keeps the progress bar intact, unlike print()
            tqdm.write(f"Skipping because output file \"{output_filepath}\" already exists")
            continue
        # Load image and convert to range [-1, +1]
        img_batch = prepare_image(filepath)
        # Feed through auto-encoder; inference only, so skip building an autograd graph
        with torch.no_grad():
            img_reconstructed_batch = vae(img_batch)["sample"]
        # Convert back to PIL image and write it out
        img_reconstructed = tensor_to_pil(img_reconstructed_batch[0])
        img_reconstructed.save(output_filepath)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Compress images with the SDXL auto-encoder")
    parser.add_argument("--input_dir", type=str, help="Path to input directory", default="ALASKA_v2_TIFF_512_COLOR")
    parser.add_argument("--output_dir", type=str, help="Path to output directory", default="/tmp")
    parser.add_argument("--max_num_images", type=int, help="Take only a limited number of samples")
    args = vars(parser.parse_args())

    # Only *.tif files directly inside input_dir are processed, in sorted order.
    filepaths = sorted(glob(os.path.join(args["input_dir"], "*.tif")))
    if not filepaths:
        # Fail before loading the (large) model when there is nothing to do.
        parser.error(f"No .tif files found in \"{args['input_dir']}\"")
    # A missing or zero value means "no limit".
    if args["max_num_images"]:
        filepaths = filepaths[:args["max_num_images"]]

    vae = load_vae()
    compress_decompress(filepaths, args["output_dir"], vae=vae)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment