Skip to content

Instantly share code, notes, and snippets.

@woshiyyya
Last active November 2, 2023 05:45
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save woshiyyya/9f631997be26b2312060e8b858960fbf to your computer and use it in GitHub Desktop.
Save woshiyyya/9f631997be26b2312060e8b858960fbf to your computer and use it in GitHub Desktop.
"""
Cluster: 16 x A10G GPUs
Command: python precompute_latents.py --subset_size 50 --mode debug
"""
import argparse
import io
import pandas as pd
import pyarrow.dataset as pds
import os
import ray
import torch
from diffusers import AutoencoderKL
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import crop
from transformers import CLIPTextModel, CLIPTokenizer
class LargestCenterSquare:
"""Center crop to the largest square of a PIL image."""
def __init__(self, size):
self.size = size
self.center_crop = transforms.CenterCrop(self.size)
def __call__(self, img):
# First, resize the image such that the smallest side is self.size while preserving aspect ratio.
img = transforms.functional.resize(img, self.size, antialias=True)
# Then take a center crop to a square.
w, h = img.size
c_top = (h - self.size) // 2
c_left = (w - self.size) // 2
img = crop(img, c_top, c_left, self.size, self.size)
return img
def transform_images(batch: pd.DataFrame) -> pd.DataFrame:
"""Transform the images and filter out the invalid ones. """
# Image transformation
center_square_crop = LargestCenterSquare(256)
normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
transform = transforms.Compose(
[center_square_crop, transforms.ToTensor(), normalize]
)
def process_row(row):
try:
image = Image.open(io.BytesIO(row["jpg"]))
if image.mode != "RGB":
image = image.convert("RGB")
row["valid"] = True
row["image"] = transform(image).tolist()
except Exception as e:
row["valid"] = False
row["image"] = row["caption_ids"] = None
return row
# Transform raw images and text captions
batch = batch.apply(process_row, axis=1)
return batch[batch["valid"]]
class LatentEncoder:
def __init__(self, model_name="stabilityai/stable-diffusion-2-base"):
print(
"CUDA_VISIBLE_DEVICES = ",
os.environ["CUDA_VISIBLE_DEVICES"],
"GPU IDS = ",
ray.get_gpu_ids(),
)
# self.device_id = ray.get_gpu_ids()[0]
self.device = "cuda"
# Image and text encoders
self.vae = AutoencoderKL.from_pretrained(
model_name, subfolder="vae", torch_dtype=torch.float16
)
self.text_tokenizer = CLIPTokenizer.from_pretrained(
model_name, subfolder="tokenizer"
)
self.text_encoder = CLIPTextModel.from_pretrained(
model_name, subfolder="text_encoder", torch_dtype=torch.float16
)
# Move the encoders to GPU
self.vae = self.vae.to(self.device)
self.text_encoder = self.text_encoder.to(self.device)
def __call__(self, batch: pd.DataFrame) -> pd.DataFrame:
caption_list = batch["caption"].tolist()
image_list = batch["image"].tolist()
tokenized_caption = self.text_tokenizer(
caption_list,
padding="max_length",
max_length=self.text_tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
# Construct input tensors for image and caption token ids
image_tensor = torch.tensor(image_list).to(self.device)
caption_ids_tensor = tokenized_caption["input_ids"].to(self.device)
with torch.no_grad():
# Encode images
image_latents_tensor = (
self.vae.encode(image_tensor.half())["latent_dist"].sample() * 0.18215
)
image_latents_numpy = image_latents_tensor.detach().cpu().numpy()
# Encode captions
caption_latents_tensor = self.text_encoder(caption_ids_tensor)[0]
caption_latents_numpy = caption_latents_tensor.detach().cpu().numpy()
# Serialize latents tensors to bytes
# Shape = [batch_size, 4, 32, 32]
batch["latents_256_bytes"] = [
latents.tobytes() for latents in image_latents_numpy
]
# Shape = [batch_size, 77, 1024]
batch["caption_latents"] = [
latents.tobytes() for latents in caption_latents_numpy
]
return batch
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_name", default="800G-LAION-art-8M", type=str)
parser.add_argument("--mode", default="debug", type=str)
parser.add_argument(
"--subset_size",
default=10,
type=int,
required=True,
help="Number of samples to generate for the output subset. Unit=K",
)
args = parser.parse_args()
# success rate = 0.74
num_samples_per_parquet = 1e4 * 0.74
num_input_parquets = int(args.subset_size * 1e3 // num_samples_per_parquet) + 1
print("num_input_parquets", num_input_parquets)
base_s3_uri = f"s3://air-example-data-2/{args.dataset_name}"
input_parquet_uris = [f"{base_s3_uri}/{i:05}.parquet" for i in range(num_input_parquets)]
# Data Pipeline
# 0. Read input parquet files from S3
ds = ray.data.read_parquet(
input_parquet_uris, filter=(pds.field("status") == "success")
)
# 1. Clean and transform images
ds = ds.map_batches(transform_images, batch_size=32, batch_format="pandas")
# 2. Encode images and captions
ds_processed = ds.map_batches(
LatentEncoder,
compute=ray.data.ActorPoolStrategy(
size=16
), # Use 16 GPUs. Change this number based on the number of GPUs in your cluster.
num_gpus=1, # Specify 1 GPU per model replica.
batch_size=128, # Use the largest batch size that can fit on our GPUs
batch_format="pandas",
)
# 3. Ensure each parquet file contains 1000 samples
df_processed = ds_processed.map_batches(lambda x: x, batch_size=1000)
# 4. Dump to S3 bucket
if args.mode == "debug":
output_uri = (
f"s3://air-example-data-2/LAION-precomputed/debug/LAION-dev-{args.subset_size}K/"
)
else:
output_uri = (
f"s3://air-example-data-2/LAION-precomputed/LAION-dev-{args.subset_size}K/"
)
ds_processed.write_parquet(output_uri, try_create_dir=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment