@cloneofsimo
Created September 17, 2023 11:36
preprocess-videos-latents
import os
import csv
import torch
import cv2
import logging
from typing import Tuple, Any, List
from torch.utils.data import DataLoader, Dataset
from multiprocessing import Pool
from streaming import MDSWriter
import ImageReward as RM
from PIL import Image
from diffusers.models import AutoencoderKL
import numpy as np
from transformers import BlipProcessor, BlipForConditionalGeneration
import pandas as pd
import json
import time
# Initialize logging
logging.basicConfig(level=logging.INFO)
from streaming.base.format.mds.encodings import Encoding, _encodings
class bf16(Encoding):
    """Raw-bytes MDS encoding; despite the name, values are written and read back as float16."""

    def encode(self, obj: Any) -> bytes:
        return obj.tobytes()

    def decode(self, data: bytes) -> Any:
        return np.frombuffer(data, np.float16)

_encodings['bf16'] = bf16
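
# Illustrative only (not used below): a minimal sketch of how a stored 'vae_output'
# field could be turned back into a latent tensor when reading the shards with
# streaming.StreamingDataset. The latent shape (1, 4, 128, 128) is an assumption,
# based on the 1024x1024 tiled input and the SDXL VAE's 8x spatial downsampling.
def decode_vae_latent(raw, shape=(1, 4, 128, 128)) -> torch.Tensor:
    # StreamingDataset hands back the flat float16 buffer produced by bf16.decode();
    # reshape it to the assumed latent layout and wrap it as a torch tensor.
    return torch.from_numpy(np.asarray(raw, dtype=np.float16).reshape(shape))
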
def prepare_image(pil_image, w=512, h=512):
    """Resize to (w, h), scale pixels to [-1, 1], and return a 1xCxHxW float tensor."""
    pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
    arr = np.array(pil_image.convert("RGB"))
    arr = arr.astype(np.float32) / 127.5 - 1
    arr = np.transpose(arr, [2, 0, 1])
    image = torch.from_numpy(arr).unsqueeze(0)
    return image

class VideoDataset(Dataset):
    def __init__(self, csv_file):
        self.video_files = pd.read_csv(csv_file)['video_path'].to_list()
        self.dataset_latency = []

    def __len__(self):
        return len(self.video_files)

    def __getitem__(self, idx):
        start_time = time.time()
        video_file = self.video_files[idx]
        frames = self._load_frames_from_video(video_file)
        # Skip clips where fewer than 4 frames could be decoded.
        if len(frames) < 4:
            return None, None
        second_image = Image.fromarray(frames[1])
        second_image = self._center_crop_square_resize(second_image, 512)
        tiled_image = self._tile_frames(frames)
        # Skip clips whose sampled frames are all (near-)identical.
        # Cast to a signed dtype first so the uint8 subtraction cannot wrap around.
        diff = np.array([
            np.abs(frames[i].astype(np.int16) - frames[i + 1].astype(np.int16)).sum()
            for i in range(3)
        ])
        if np.all(diff < 1e-5):
            return None, None
        self.dataset_latency.append(time.time() - start_time)
        return second_image, prepare_image(tiled_image, 1024, 1024)

    def _load_frames_from_video(self, video_path):
        # Grab 4 frames spaced roughly a quarter of a second apart.
        vid = cv2.VideoCapture(video_path)
        fps = int(vid.get(cv2.CAP_PROP_FPS))
        quarter_second_interval = fps // 4
        interval = max(1, quarter_second_interval)
        frames = []
        for i in range(4):
            vid.set(cv2.CAP_PROP_POS_FRAMES, i * interval)
            ret, frame = vid.read()
            if not ret:
                break
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(frame_rgb)
        vid.release()
        return frames

    def _center_crop_square_resize(self, image, output_size):
        # Resize so the shorter side equals output_size, maintaining aspect ratio.
        aspect = image.width / image.height
        if aspect > 1:
            # Landscape orientation - wide image
            width = int(output_size * aspect)
            height = output_size
        else:
            # Portrait orientation - tall image
            width = output_size
            height = int(output_size / aspect)
        image = image.resize((width, height), Image.BICUBIC)
        # Center crop to output_size x output_size.
        left = (image.width - output_size) // 2
        top = (image.height - output_size) // 2
        right = left + output_size
        bottom = top + output_size
        return image.crop((left, top, right, bottom))

    def _tile_frames(self, frames):
        # Paste the 4 frames into a 2x2 grid of 512x512 tiles (1024x1024 total).
        tiled_image = Image.new('RGB', (512 * 2, 512 * 2))
        for i, frame in enumerate(frames):
            pil_frame = Image.fromarray(frame)
            # Square resize & center-crop each frame before pasting.
            pil_frame = self._center_crop_square_resize(pil_frame, 512)
            tiled_image.paste(pil_frame, (512 * (i % 2), 512 * (i // 2)))
        return tiled_image
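
# Quick sanity check (illustrative, not executed by this script): build the dataset
# from one of the per-group CSVs written by main() below and inspect a single item.
# The CSV path is an assumption; any CSV with a 'video_path' column works.
#
#   ds = VideoDataset('out_root/group_0/data.csv')
#   mid_frame, tiled_tensor = ds[0]
#   if mid_frame is not None:
#       print(mid_frame.size, tiled_tensor.shape)  # (512, 512), torch.Size([1, 3, 1024, 1024])
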
@torch.no_grad()
def convert_to_mds(args: Tuple[List[str], torch.device]):
    sub_out_roots, device = args
    logging.info(f"Processing on {device}")
    logging.info(f"Processing {sub_out_roots}")

    # Set the device for the current process.
    torch.cuda.set_device(device)

    # Initialize the models.
    image_reward_model = RM.load("ImageReward-v1.0").to(device).eval()

    # VAE model
    vae_model = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").half()
    vae_model = vae_model.to(device).eval()

    blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
    blip_model = BlipForConditionalGeneration.from_pretrained(
        "Salesforce/blip-image-captioning-large", torch_dtype=torch.float16
    ).to(device).eval()

    # Load the dataset for each assigned group.
    for sub_out_root in sub_out_roots:
        dataset = VideoDataset(os.path.join(sub_out_root, 'data.csv'))
        sub_data_root = os.path.join(sub_out_root, 'data')

        columns = {
            'reward_output': 'float32',
            'vae_output': 'bf16',
            'caption_output': 'str',
        }

        if os.path.exists(sub_data_root):
            # Remove all files in the directory.
            for file in os.listdir(sub_data_root):
                os.remove(os.path.join(sub_data_root, file))
        os.makedirs(sub_data_root, exist_ok=True)

        with MDSWriter(out=sub_data_root, columns=columns) as out:
            inference_latencies = []
            for data_mid, data_all in dataset:
                if data_mid is None:
                    continue
                data_all = data_all.to(device)
                start_time = time.time()

                # Caption the middle frame with BLIP.
                blip_inputs = blip_processor(data_mid, text=" ", return_tensors="pt").to(torch.float16).to(device)
                blip_out = blip_model.generate(**blip_inputs, max_new_tokens=25, min_new_tokens=4, do_sample=True, top_k=50, temperature=0.7)
                generated_captions = blip_processor.batch_decode(blip_out, skip_special_tokens=True)[0]

                # Score caption/frame agreement with ImageReward.
                reward_output = image_reward_model.score(generated_captions, data_mid)

                # Encode the 1024x1024 tiled image to SDXL VAE latents.
                vae_output = vae_model.encode(data_all.half()).latent_dist.sample()
                print(reward_output, vae_output.shape, generated_captions)

                # Save the outputs to MDS.
                sample = {
                    'reward_output': reward_output,
                    'vae_output': vae_output.cpu().half().numpy(),
                    'caption_output': generated_captions,
                }
                out.write(sample)
                inference_latencies.append(time.time() - start_time)

            print(f"Average Inference Latency on {device}: {np.mean(inference_latencies)} seconds")
            print(f"Average Dataset Processing Latency {device}: {np.mean(dataset.dataset_latency)} seconds")
    return True

def init_worker():
    pid = os.getpid()
    print(f'\nInitialize Worker PID: {pid}', flush=True, end='')

def main(video_files: List[str], out_root):
    # Group the video paths into batches of 1024.
    grouped_datasets = [video_files[i:i + 1024] for i in range(0, len(video_files), 1024)]

    # Make sure we have enough groups for our GPUs.
    num_gpus = torch.cuda.device_count()
    assert len(grouped_datasets) >= num_gpus, f"Not enough data for {num_gpus} GPUs."

    # Write each group's video paths to its own CSV.
    os.makedirs(out_root, exist_ok=True)
    grouped_paths = []
    for i, dataset_group in enumerate(grouped_datasets):
        group_path = os.path.join(out_root, f'group_{i}')
        os.makedirs(group_path, exist_ok=True)
        df = pd.DataFrame(dataset_group, columns=['video_path'])
        csvpath = os.path.join(group_path, 'data.csv')
        df.to_csv(csvpath, index=False)
        grouped_paths.append(group_path)
    print(grouped_paths)
    logging.info("Videos preprocessed to CSV files.")

    # Create a round-robin GPU assignment.
    gpu_assignments = [([grouped_paths[i]], torch.device(f'cuda:{i % num_gpus}')) for i in range(len(grouped_paths))]

    with Pool(num_gpus, initializer=init_worker) as pool:
        pool.map(convert_to_mds, gpu_assignments)
    print('Finished')
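
# Illustrative only: one way the resulting shards could be consumed for training,
# assuming streaming.StreamingDataset and the per-group output layout produced above
# (out_root/group_<i>/data). Field names match the MDSWriter columns.
#
#   from streaming import StreamingDataset
#   train_ds = StreamingDataset(local='out_root/group_0/data', shuffle=True)
#   sample = train_ds[0]
#   caption = sample['caption_output']
#   latent = decode_vae_latent(sample['vae_output'])  # helper sketched above, shape is an assumption
#   reward = sample['reward_output']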
@cloneofsimo (Author)

from processor import main
import os
import json

if __name__ == "__main__":
    import torch.multiprocessing as mp
    mp.set_start_method('spawn')
    
    def extract_video_paths(file_path):
        video_paths = []
        with open(file_path, 'r') as f:
            for line in f:
                obj = json.loads(line)
                video_path = obj.get('video_path')
                if video_path:
                    video_paths.append(video_path)
        return video_paths

    file_path = "/scratch/slurm-user9-unmanned/vim/_webvid_2500k/results_2M_train_fps2_frames8.jsonl"
    video_files = extract_video_paths(file_path)

    # Output root for the per-group CSVs and MDS shards; "./mds_out" is a placeholder,
    # point it at your own scratch directory.
    out_root = "./mds_out"
    main(video_files, out_root)
