Skip to content

Instantly share code, notes, and snippets.

@YodaEmbedding
Last active February 26, 2024 06:54
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save YodaEmbedding/8803d95de072f12b4ff14ffd2b5bd7e5 to your computer and use it in GitHub Desktop.
Generate Vimeo90K NumPy dataset
import argparse
from pathlib import Path
import numpy as np
import torch
from compressai.datasets import Vimeo90kDataset
from torch.utils.data import DataLoader
from torchvision import transforms
# Usage banner printed unconditionally at startup (see main()); documents how
# to download the dataset and the four supported invocations.
MESSAGE = """
Download and extract the Vimeo90k dataset first:
mkdir -p vimeo90k
cd vimeo90k
wget http://data.csail.mit.edu/tofu/dataset/vimeo_triplet.zip
unzip vimeo_triplet.zip
wget http://data.csail.mit.edu/tofu/dataset/vimeo_septuplet.zip
unzip vimeo_septuplet.zip
cd ..
Then, run one of the following:
python generate_vimeo90k_npy_dataset.py --tuplet=3 --mode=image --indir="vimeo90k/vimeo_triplet" --outdir="vimeo90k/vimeo_triplet_npy"
python generate_vimeo90k_npy_dataset.py --tuplet=3 --mode=video --indir="vimeo90k/vimeo_triplet" --outdir="vimeo90k/vimeo_triplet_npy_video"
python generate_vimeo90k_npy_dataset.py --tuplet=7 --mode=image --indir="vimeo90k/vimeo_septuplet" --outdir="vimeo90k/vimeo_septuplet_npy"
python generate_vimeo90k_npy_dataset.py --tuplet=7 --mode=video --indir="vimeo90k/vimeo_septuplet" --outdir="vimeo90k/vimeo_septuplet_npy_video"
If the mode is "image", each frame is treated separately, and may
undergo different transformations.
If the mode is "video", all frames undergo the same transformation.
"""
# Side length (pixels) of the square patch cropped from every frame.
PATCH_LENGTH = 256
# (height, width) crop size passed to the torchvision crop transforms.
PATCH_SIZE = (PATCH_LENGTH, PATCH_LENGTH)
# Output .npy filename stem for each dataset split key.
FILENAMES = {
    "train": "training",
    "valid": "validation",
}
def get_dataset(dataset_path, split, tuplet, mode):
    """Build a Vimeo90k dataset plus a deterministic DataLoader over it.

    Args:
        dataset_path: Root directory of the extracted Vimeo90k dataset.
        split: "train" (random crops) or anything else (center crops).
        tuplet: Number of frames per sample (3 for triplet, 7 for septuplet).
        mode: "image" or "video"; selects the channels-last permutation
            applied after cropping.

    Returns:
        A ``(dataset, loader)`` tuple. The loader uses ``shuffle=False`` so
        iteration order is reproducible across runs.
    """
    if split == "train":
        crop = transforms.RandomCrop(PATCH_SIZE)
    else:
        crop = transforms.CenterCrop(PATCH_SIZE)

    # Axis orders that move the channel axis last. An unrecognized mode maps
    # to None (unreachable via the CLI, which restricts --mode choices).
    perm_axes = {"image": (1, 2, 0), "video": (0, -2, -1, -3)}.get(mode)
    channels_last = (
        (lambda x: x.permute(*perm_axes)) if perm_axes is not None else None
    )

    transform = transforms.Compose(
        [
            crop,
            # lambda img: torch.from_numpy(np.array(img)),
            channels_last,
            # transforms.ToTensor(),  # NOTE: Converts HWC -> CHW.
        ]
    )
    dataset = Vimeo90kDataset(
        root=dataset_path,
        transform=transform,
        split=split,
        tuplet=tuplet,
        # The following parameters are experimental.
        # Old versions of CompressAI do not have these,
        # and behave as if mode="image".
        mode=mode,
        transform_frame=transforms.ToTensor(),  # NOTE: Converts HWC -> CHW.
    )
    loader = DataLoader(dataset, batch_size=16, shuffle=False, num_workers=8)
    return dataset, loader
def generate_npy_dataset(indir, outdir, split, tuplet, mode, epochs):
    """Write one split of the Vimeo90k dataset as a single uint8 ``.npy`` file.

    Args:
        indir: Root of the extracted Vimeo90k dataset.
        outdir: Directory for the output file (created if absent).
        split: "train" or "valid" (must be a key of FILENAMES).
        tuplet: Frames per sample (3 for triplet, 7 for septuplet).
        mode: "image" (one frame per item, HWC) or "video"
            (tuplet frames per item, THWC).
        epochs: Number of passes over the dataset; with the train split's
            RandomCrop, epochs > 1 yields differently-cropped copies.

    Raises:
        ValueError: If ``mode`` is not "image" or "video".
    """
    dataset, loader = get_dataset(indir, split, tuplet, mode)
    out_filepath = Path(f"{outdir}/{FILENAMES[split]}.npy")
    # parents=True: the outdir path may be nested and not yet exist.
    # (Previously only exist_ok=True was passed, so a missing grandparent
    # directory raised FileNotFoundError.)
    out_filepath.parent.mkdir(parents=True, exist_ok=True)
    print(f"Writing to {out_filepath}...")
    if mode == "image":
        shape = (epochs * len(dataset), *PATCH_SIZE, 3)
    elif mode == "video":
        shape = (epochs * len(dataset), tuplet, *PATCH_SIZE, 3)
    else:
        # Previously fell through with `shape` unbound, causing an opaque
        # NameError below; fail fast with a clear message instead.
        raise ValueError(f"Unknown mode: {mode!r}; expected 'image' or 'video'.")
    # memmap streams the output to disk instead of holding it all in RAM.
    x_out = np.memmap(out_filepath, dtype="uint8", mode="w+", shape=shape)
    offset = 0
    for epoch in range(epochs):
        for i, x in enumerate(loader):
            # Loader yields floats in [0, 1] (ToTensor); quantize to uint8.
            x = (x * 255).to(torch.uint8)
            print(
                f"{split} | "
                f"{epoch} / {epochs} epochs | "
                f"{offset:6d} / {len(dataset)} items | "
                f"{i:6d} / {len(loader)} batches | "
                # For ensuring that random output is stable:
                f"checksum: {x.min():3.0f} {x.max():3.0f} {x.to(float).mean():3.0f}"
            )
            x_out[offset : offset + len(x)] = x.numpy()
            offset += len(x)
    x_out.flush()
    del x_out
def parse_args():
    """Parse command-line options for the Vimeo90k dataset generator."""
    parser = argparse.ArgumentParser(description="Generate Vimeo90k dataset")
    # Plain string options: flag -> default.
    for flag, default in {
        "--indir": "vimeo90k/vimeo_triplet",
        "--outdir": "vimeo90k/vimeo_triplet_npy",
    }.items():
        parser.add_argument(flag, default=default)
    parser.add_argument("--mode", default="image", choices=["image", "video"])
    # Integer options: flag -> default.
    for flag, default in {"--tuplet": 3, "--seed": 1234, "--epochs": 1}.items():
        parser.add_argument(flag, type=int, default=default)
    return parser.parse_args()
def main():
    """Entry point: print usage, seed RNG, and write both dataset splits."""
    print(MESSAGE)
    args = parse_args()
    # Seed torch so RandomCrop placements are reproducible across runs.
    torch.manual_seed(args.seed)
    shared_kwargs = dict(
        indir=args.indir,
        outdir=args.outdir,
        tuplet=args.tuplet,
        mode=args.mode,
        epochs=args.epochs,
    )
    for split in ("train", "valid"):
        generate_npy_dataset(split=split, **shared_kwargs)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment