Skip to content

Instantly share code, notes, and snippets.

@eddyxu
Last active April 12, 2024 19:19
Show Gist options
  • Save eddyxu/a3cb9097ae7008e522ef3a2f47834861 to your computer and use it in GitHub Desktop.
Save eddyxu/a3cb9097ae7008e522ef3a2f47834861 to your computer and use it in GitHub Desktop.
hd-vila-ray
#!/usr/bin/env python3
#
# Generate HD-Vila-100M dataset
#
# https://github.com/microsoft/XPretrain/tree/main/hd-vila-100m
import argparse
import datetime
import logging
from collections import defaultdict
from pathlib import Path
from subprocess import CalledProcessError, check_call
from tempfile import TemporaryDirectory
from typing import Any, Generator, List, Tuple

import pyarrow as pa
import ray
import yt_dlp as youtube_dl
from lance.ray.sink import LanceCommitter, LanceFragmentWriter
# yt-dlp format selector to use for each target resolution.
FORMAT_IDS = {"720p": "22"}

# Midnight on strptime's default date (1900-01-01); subtracting it from a
# parsed "HH:MM:SS.fff" time yields the offset as a timedelta.
ZERO_DATETIME = datetime.datetime(1900, 1, 1)

# Output table layout: one row per clip, with the clip's raw bytes in "video".
SCHEMA = pa.schema(
    [
        ("video_id", pa.string()),
        ("clip_id", pa.string()),
        ("start", pa.float32()),
        ("duration", pa.float32()),
        ("video", pa.binary()),
    ]
)
def parse_span(span_str: List[str]) -> Tuple[float, float]:
    """Convert a ``[start, end]`` timestamp pair into ``(start, duration)`` seconds.

    Each timestamp has the form ``HH:MM:SS[.ffffff]``. The parts are parsed by
    hand rather than with ``strptime("%H:%M:%S.%f")`` because strptime rejects
    hour values of 24 or more and timestamps without a fractional part.

    Parameters
    ----------
    span_str : List[str]
        Exactly two timestamps: the clip start and the clip end.

    Returns
    -------
    Tuple[float, float]
        ``(start, duration)`` in seconds, computed at microsecond precision.
        ``duration`` is ``end - start``, so it is negative if ``end`` precedes
        ``start``.

    Raises
    ------
    ValueError
        If ``span_str`` does not contain exactly two timestamps, or a
        timestamp does not have three ``:``-separated numeric fields.
    """

    def _to_timedelta(stamp: str) -> datetime.timedelta:
        # timedelta rounds the float seconds to whole microseconds, so the
        # subtraction below is exact, just like subtracting parsed datetimes.
        hours, minutes, seconds = str(stamp).split(":")
        return datetime.timedelta(
            hours=int(hours), minutes=int(minutes), seconds=float(seconds)
        )

    start, end = (_to_timedelta(s) for s in span_str)
    return (start.total_seconds(), (end - start).total_seconds())
def download_youtube(
    video_id: str,
    output_dir: Path,
    format_id: str = FORMAT_IDS["720p"],
    ext: str = "mp4",
) -> Path | None:
    """Download one YouTube video into *output_dir* with yt-dlp.

    Parameters
    ----------
    video_id : str
        YouTube video id; fetched from ``https://www.youtube.com/watch?v=<id>``.
    output_dir : Path
        Directory the file is written to, named after the yt-dlp template
        ``%(id)s.%(ext)s``.
    format_id : str
        yt-dlp ``format`` selector (defaults to ``FORMAT_IDS["720p"]``).
    ext : str
        Expected file extension; also passed as ``merge_output_format``.

    Returns
    -------
    Path | None
        ``output_dir / f"{video_id}.{ext}"`` on success, or ``None`` if yt-dlp
        reported an error or that file was not produced.
    """
    url = f"https://www.youtube.com/watch?v={video_id}"
    dl_opts = {
        "outtmpl": str(output_dir / "%(id)s.%(ext)s"),
        "merge_output_format": ext,
        "format": format_id,
        "skip_download": False,
        "ignoreerrors": True,
        "quiet": True,
        "max_sleep_interval": 15,
    }
    with youtube_dl.YoutubeDL(dl_opts) as ydl:
        rst = ydl.download([url])
    if rst != 0:
        logging.error("Failed to download video %s", video_id)
        return None

    video_path = output_dir / f"{video_id}.{ext}"
    # The path is derived from the template above, not reported by yt-dlp; if
    # the chosen format was saved under another extension the file is not
    # here. Report that now instead of letting ffmpeg fail on it later.
    if not video_path.exists():
        logging.error("Downloaded video %s not found at %s", video_id, video_path)
        return None
    return video_path
def cut_clip(video_path: Path, start: float, duration: float, output_path: Path):
    """Extract a clip of ``duration`` seconds beginning at ``start`` using ffmpeg.

    Streams are copied without re-encoding (``-c copy``), an existing file at
    *output_path* is overwritten (``-y``), and ffmpeg's logging is silenced
    (``-loglevel quiet``).

    NOTE(review): with stream copy, ffmpeg generally can only begin a cut on a
    keyframe, so the clip start may not match ``start`` exactly. Confirm that
    precision is acceptable.

    Parameters
    ----------
    video_path : Path
        Source video file.
    start : float
        Clip start offset, in seconds.
    duration : float
        Clip length, in seconds.
    output_path : Path
        Destination file for the clip.

    Raises
    ------
    subprocess.CalledProcessError
        If ffmpeg exits with a non-zero status.
    """
    seek = ["-ss", str(start)]  # placed before -i so ffmpeg seeks the input
    source = ["-i", str(video_path)]
    length = ["-t", str(duration)]
    codec_and_logging = ["-c", "copy", "-loglevel", "quiet"]
    command = (
        ["ffmpeg", "-y"]
        + seek
        + source
        + length
        + codec_and_logging
        + [str(output_path)]
    )
    check_call(command)
def transform_clip(batch: pa.Table) -> Generator[dict[str, Any], None, None]:
    """Download each video in *batch* and yield its cut clips as column lists.

    For every video, the source is downloaded into a temporary directory and
    each clip listed under ``batch["clip"]`` is cut out with ffmpeg. One dict
    mapping ``video_id``/``clip_id``/``start``/``duration``/``video`` to lists
    of per-clip values is yielded per video.

    Skipping rules:
    - A video is skipped if its download fails or its clip list is null.
    - A clip whose ffmpeg call fails is logged and skipped.
    - Nothing is yielded for a video that ends up with no clips.

    Parameters
    ----------
    batch : pa.Table
        Rows with a ``video_id`` column and a ``clip`` column whose entries are
        lists of records with ``clip_id`` and a two-element ``span``.
        NOTE(review): depending on the Ray batch format this may arrive as a
        dict of column arrays rather than a ``pa.Table``. Confirm.

    Yields
    ------
    dict[str, Any]
        Column name -> list of values for the clips of a single video.
    """
    for video_id, clips in zip(batch["video_id"], batch["clip"]):
        if clips is None:
            # Nothing to cut, so avoid downloading the video at all.
            continue
        results = defaultdict(list)
        with TemporaryDirectory() as tmp_dir:
            dl_dir = Path(tmp_dir)
            video_path = download_youtube(video_id, dl_dir)
            if video_path is None:
                continue
            for clip in clips:
                clip_id = clip["clip_id"]
                start, duration = parse_span(clip["span"])
                logging.debug(
                    "Cutting %s: start=%s duration=%s", video_id, start, duration
                )
                clip_file = dl_dir / f"{clip_id}.mp4"
                try:
                    cut_clip(video_path, start, duration, clip_file)
                except CalledProcessError:
                    # Handle a failed cut like a failed download (log and
                    # continue) instead of aborting every video in the batch.
                    logging.error(
                        "Failed to cut clip %s of video %s", clip_id, video_id
                    )
                    continue
                results["video_id"].append(video_id)
                results["clip_id"].append(clip_id)
                results["start"].append(start)
                results["duration"].append(duration)
                results["video"].append(clip_file.read_bytes())
        # Don't hand the writer an empty batch (no clips, or every cut failed).
        if results:
            yield dict(results)
def dataset_generator(args):
    """Build the Lance dataset from HD-VILA JSONL files with Ray Data.

    Pipeline steps:
    1. Read the JSONL inputs, optionally truncated to ``args.limit`` rows.
    2. Repartition into 8 blocks.
    3. Run ``transform_clip`` inside ``LanceFragmentWriter`` to download,
       cut and write fragments.
    4. Commit the fragments to ``args.output`` with ``SCHEMA``.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed CLI arguments. Reads ``input``, ``output``, ``limit``,
        ``ray_object_store_memory_gb`` and (optionally) ``max_rows_per_file``.
    """
    obj_store_memory = args.ray_object_store_memory_gb * 1024**3
    ray.init(object_store_memory=obj_store_memory)
    # Apply the same object-store budget to the Ray Data executor.
    context = ray.data.DataContext.get_current()
    context.execution_options.resource_limits.object_store_memory = obj_store_memory

    ds = ray.data.read_json(args.input, override_num_blocks=8)
    if args.limit:
        ds = ds.limit(args.limit)

    # Honour --max-rows-per-file; fall back to the previous hard-coded 4096
    # for callers that build ``args`` without that attribute.
    max_rows_per_file = getattr(args, "max_rows_per_file", 4 * 1024)

    # NOTE(review): --batch and --concurrency are parsed in main() but not
    # used here; batch_size is kept at 1024 to preserve current behavior.
    # write_datasink() returns None, so the chain is not assigned.
    (
        ds.repartition(8)
        .materialize()
        .map_batches(
            LanceFragmentWriter(
                args.output,
                transform=transform_clip,
                max_rows_per_group=128,  # Only for format V1.
                max_rows_per_file=max_rows_per_file,
            ),
            batch_size=1024,  # each video split into 10-50 clips
        )
        .write_datasink(LanceCommitter(args.output, SCHEMA))
    )
def main():
    """Parse the command line and run the dataset generator."""
    cli = argparse.ArgumentParser(
        description="Generate HD-Vila-100M dataset in Lance",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # (flags, keyword options) in the order they appear in --help.
    option_specs = [
        (
            ("-o", "--output"),
            dict(help="Output directory", default="hd-vila.lance", metavar="DIR"),
        ),
        (
            ("-l", "--limit"),
            dict(
                help="limit number of video to proceed",
                default=None,
                metavar="NUM",
                type=int,
            ),
        ),
        (
            ("-c", "--concurrency"),
            dict(help="How many Ray workers to use", default=8, type=int, metavar="NUM"),
        ),
        (
            ("-b", "--batch"),
            dict(help="batch size", default=10, type=int, metavar="NUM"),
        ),
        (
            ("--max-rows-per-file",),
            dict(help="max rows per file", default=4 * 1024, type=int, metavar="NUM"),
        ),
        (
            ("--ray-object-store-memory-gb",),
            dict(type=int, default=8, metavar="NUM", help="Ray object store memory in GB"),
        ),
        (
            ("input",),
            dict(
                help="Input files (HD-Vila-100M dataset)",
                nargs="+",
                metavar="PART.JSONL",
            ),
        ),
    ]
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)

    dataset_generator(cli.parse_args())


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment