Last active
November 2, 2023 05:45
-
-
Save woshiyyya/9f631997be26b2312060e8b858960fbf to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Precompute Stable Diffusion VAE image latents and CLIP caption latents for a LAION subset.

Cluster: 16 x A10G GPUs
Command: python precompute_latents.py --subset_size 50 --mode debug
""" | |
import argparse | |
import io | |
import pandas as pd | |
import pyarrow.dataset as pds | |
import os | |
import ray | |
import torch | |
from diffusers import AutoencoderKL | |
from PIL import Image | |
from torchvision import transforms | |
from torchvision.transforms.functional import crop | |
from transformers import CLIPTextModel, CLIPTokenizer | |
class LargestCenterSquare:
    """Resize a PIL image so its shorter side is ``size``, then take the centered ``size`` x ``size`` square."""

    def __init__(self, size):
        self.size = size
        # Public attribute kept for compatibility; __call__ computes its own
        # floor-division crop offsets instead of using it.
        self.center_crop = transforms.CenterCrop(self.size)

    def __call__(self, img):
        side = self.size
        # An int size matches the shorter edge and preserves the aspect ratio.
        resized = transforms.functional.resize(img, side, antialias=True)
        width, height = resized.size
        top, left = (height - side) // 2, (width - side) // 2
        return crop(resized, top, left, side, side)
def transform_images(batch: pd.DataFrame) -> pd.DataFrame:
    """Decode, center-crop and normalize the JPEG images in ``batch``; drop bad rows.

    Args:
        batch: Rows with a ``jpg`` column of encoded image bytes.

    Returns:
        Only the rows whose image could be processed, with two added columns:
        ``valid`` (True for every returned row) and ``image`` (RGB pixels as
        nested Python lists of shape [3, 256, 256], normalized to [-1, 1]).
        The output columns are the same for every batch, including an empty
        input batch or one in which every row is dropped.
    """
    if len(batch) == 0:
        # DataFrame.apply on an empty frame returns it unchanged (without the
        # "valid" column), which would make the filter below raise KeyError.
        return batch.assign(valid=pd.Series(dtype=bool), image=pd.Series(dtype=object))

    transform = transforms.Compose(
        [
            LargestCenterSquare(256),
            transforms.ToTensor(),  # PIL image -> CHW float tensor in [0, 1]
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # -> [-1, 1]
        ]
    )

    def process_row(row: pd.Series) -> pd.Series:
        try:
            with Image.open(io.BytesIO(row["jpg"])) as image:
                rgb_image = image if image.mode == "RGB" else image.convert("RGB")
                pixels = transform(rgb_image).tolist()
        except Exception:
            # Best-effort filter: corrupt/undecodable images are dropped, not fatal.
            # Only touch the columns valid rows also get, so the schema stays stable.
            row["valid"] = False
            row["image"] = None
        else:
            row["valid"] = True
            row["image"] = pixels
        return row

    batch = batch.apply(process_row, axis=1)
    return batch[batch["valid"]]
class LatentEncoder:
    """Encode pre-processed images and captions into Stable Diffusion latents.

    Used as a Ray Data actor class: the models are loaded once per replica in
    ``__init__`` and reused for every batch passed to ``__call__``.
    """

    # Stable Diffusion VAE latent scaling factor, applied to the sampled latents.
    VAE_SCALING_FACTOR = 0.18215

    def __init__(self, model_name="stabilityai/stable-diffusion-2-base"):
        """Load the float16 VAE and CLIP text encoder of ``model_name`` onto CUDA.

        Args:
            model_name: Hugging Face model id whose ``vae``, ``tokenizer`` and
                ``text_encoder`` subfolders are loaded.
        """
        # Debug aid: show which GPU Ray assigned to this replica. ``.get`` keeps
        # an unset CUDA_VISIBLE_DEVICES (e.g. a local run) from crashing init.
        print(
            "CUDA_VISIBLE_DEVICES = ",
            os.environ.get("CUDA_VISIBLE_DEVICES"),
            "GPU IDS = ",
            ray.get_gpu_ids(),
        )
        self.device = "cuda"

        # Image encoder (VAE) and text tokenizer/encoder, both encoders in float16.
        self.vae = AutoencoderKL.from_pretrained(
            model_name, subfolder="vae", torch_dtype=torch.float16
        ).to(self.device)
        self.text_tokenizer = CLIPTokenizer.from_pretrained(
            model_name, subfolder="tokenizer"
        )
        self.text_encoder = CLIPTextModel.from_pretrained(
            model_name, subfolder="text_encoder", torch_dtype=torch.float16
        ).to(self.device)

    def __call__(self, batch: pd.DataFrame) -> pd.DataFrame:
        """Add serialized image and caption latents to ``batch``.

        Args:
            batch: Rows with a ``caption`` text column and an ``image`` column
                of nested pixel lists (presumably [3, 256, 256] as produced by
                ``transform_images`` -- TODO confirm for other producers).

        Returns:
            ``batch`` with two added bytes columns: ``latents_256_bytes``
            (VAE latents, [4, 32, 32] per row) and ``caption_latents`` (text
            encoder hidden states, [77, 1024] per row). The buffers are raw
            array bytes with no dtype/shape header.
        """
        if len(batch) == 0:
            # Nothing to encode (the tokenizer/VAE would fail on an empty batch);
            # still emit the output columns so every batch has the same schema.
            return batch.assign(
                latents_256_bytes=pd.Series(dtype=object),
                caption_latents=pd.Series(dtype=object),
            )

        caption_list = batch["caption"].tolist()
        image_list = batch["image"].tolist()

        tokenized_caption = self.text_tokenizer(
            caption_list,
            padding="max_length",
            max_length=self.text_tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )

        # Input tensors for the image pixels and the caption token ids.
        image_tensor = torch.tensor(image_list).to(self.device)
        caption_ids_tensor = tokenized_caption["input_ids"].to(self.device)

        with torch.no_grad():
            # Sample from the VAE posterior and scale to the diffusion latent space.
            latent_dist = self.vae.encode(image_tensor.half())["latent_dist"]
            image_latents = latent_dist.sample() * self.VAE_SCALING_FACTOR
            image_latents_numpy = image_latents.detach().cpu().numpy()

            # First output of the text encoder: per-token hidden states.
            caption_latents = self.text_encoder(caption_ids_tensor)[0]
            caption_latents_numpy = caption_latents.detach().cpu().numpy()

        # Serialize each row's latents to bytes.
        # Image latents shape = [batch_size, 4, 32, 32]
        batch["latents_256_bytes"] = [latents.tobytes() for latents in image_latents_numpy]
        # Caption latents shape = [batch_size, 77, 1024]
        batch["caption_latents"] = [latents.tobytes() for latents in caption_latents_numpy]
        return batch
def parse_args(argv=None):
    """Parse the command-line options (``argv=None`` reads ``sys.argv``)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_name", default="800G-LAION-art-8M", type=str)
    # "debug" writes under the debug/ output prefix; any other value writes to
    # the main LAION-precomputed prefix.
    parser.add_argument("--mode", default="debug", type=str)
    # required=True makes any default unreachable, so none is given.
    parser.add_argument(
        "--subset_size",
        type=int,
        required=True,
        help="Number of samples to generate for the output subset. Unit=K",
    )
    parser.add_argument(
        "--num_gpus",
        default=16,
        type=int,
        help="Number of LatentEncoder replicas, one GPU each. Match your cluster.",
    )
    return parser.parse_args(argv)


def count_input_parquets(subset_size_k, samples_per_parquet=1e4, success_rate=0.74):
    """Estimate how many input parquet files yield ``subset_size_k`` thousand usable samples.

    Floor-divides and then adds one file as a safety margin, because
    ``success_rate`` (fraction of rows with status "success") is only an estimate.
    """
    usable_per_parquet = samples_per_parquet * success_rate
    return int(subset_size_k * 1e3 // usable_per_parquet) + 1


def output_uri_for(mode, subset_size_k):
    """Return the S3 output prefix for ``mode`` ("debug" or anything else)."""
    prefix = "s3://air-example-data-2/LAION-precomputed/"
    if mode == "debug":
        prefix += "debug/"
    return f"{prefix}LAION-dev-{subset_size_k}K/"


def main(argv=None):
    """Run the read -> clean -> encode -> re-batch -> write pipeline."""
    args = parse_args(argv)

    num_input_parquets = count_input_parquets(args.subset_size)
    print("num_input_parquets", num_input_parquets)
    base_s3_uri = f"s3://air-example-data-2/{args.dataset_name}"
    input_parquet_uris = [f"{base_s3_uri}/{i:05}.parquet" for i in range(num_input_parquets)]

    # 0. Read input parquet files from S3, keeping only successfully scraped rows.
    ds = ray.data.read_parquet(
        input_parquet_uris, filter=(pds.field("status") == "success")
    )

    # 1. Clean and transform images (invalid images are dropped).
    ds = ds.map_batches(transform_images, batch_size=32, batch_format="pandas")

    # The input estimate over-reads on purpose; trim to the requested subset
    # size before spending GPU time on rows that would not be kept.
    ds = ds.limit(args.subset_size * 1000)

    # 2. Encode images and captions, one GPU per model replica.
    ds = ds.map_batches(
        LatentEncoder,
        compute=ray.data.ActorPoolStrategy(size=args.num_gpus),
        num_gpus=1,
        batch_size=128,  # Use the largest batch size that fits on the GPUs.
        batch_format="pandas",
    )

    # 3. Re-batch into 1000-row blocks so each written parquet file holds 1000
    #    samples. (Previously computed but never written.)
    ds = ds.map_batches(lambda batch: batch, batch_size=1000, batch_format="pandas")

    # 4. Dump to S3 bucket.
    ds.write_parquet(output_uri_for(args.mode, args.subset_size), try_create_dir=True)


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment