Skip to content

Instantly share code, notes, and snippets.

@pinxau1000
Created December 2, 2021 17:50
Show Gist options
  • Save pinxau1000/f3f854199c338b679ea0c13923513607 to your computer and use it in GitHub Desktop.
Save pinxau1000/f3f854199c338b679ea0c13923513607 to your computer and use it in GitHub Desktop.
Python Script for VMAF Computation using VMAF CLI and Multiprocessing
import argparse
import multiprocessing
import os
import re
import subprocess
from functools import partial
import tqdm
def compute_vmaf(vmaf_cli_path: str, orig_image: str, recon_image: str, output_type: str = "json",
                 output_dir: str = None, save_output_to_file: bool = False, **kwargs):
    """Run the VMAF CLI on one original/reconstructed YUV image pair.

    The frame size and the pixel format are read from the file name of
    ``orig_image``: the first ``<digits>x<digits>`` group is used as
    ``width x height`` and the last ``_``-separated token of the name without
    extension (e.g. ``yuv420p``) gives the pixel format once ``yuv`` and ``p``
    are removed (e.g. ``420``). The bit depth is fixed to 8.

    Args:
        vmaf_cli_path: Executable to run.
        orig_image: Reference image, passed to the CLI with ``-r``.
        recon_image: Distorted image, passed to the CLI with ``-d``.
        output_type: Report format: ``json``, ``xml``, ``csv`` or ``sub``
            (case-insensitive). Any other value falls back to ``json``.
        output_dir: Directory receiving the report (created if missing).
            Defaults to the directory of ``recon_image``, or the current
            directory when ``recon_image`` has no directory component.
        save_output_to_file: When true, the captured stdout and stderr of the
            CLI are written next to the report as ``*_vmaf.stdout`` and
            ``*_vmaf.stderr``.
        **kwargs: Extra CLI arguments; every key and value is appended to the
            command as two consecutive arguments.

    Returns:
        The path of the report file, or ``None`` when the resolution cannot
        be parsed from the file name or the CLI exits with a non-zero status.
    """
    orig_name = os.path.basename(orig_image)
    orig_stem = os.path.splitext(orig_name)[0]

    match = re.search(r"(\d+)x(\d+)", orig_name)
    if not match:
        return None
    width, height = match.groups()

    pixel_format = orig_stem.split("_")[-1].replace("yuv", "").replace("p", "")

    # Accept upper/mixed case spellings; anything unknown defaults to json.
    output_type = output_type.lower() if isinstance(output_type, str) else "json"
    if output_type not in ("json", "xml", "csv", "sub"):
        output_type = "json"

    if not output_dir:
        # dirname() is "" for a bare file name and os.makedirs("") raises.
        output_dir = os.path.dirname(recon_image) or os.curdir
    os.makedirs(output_dir, exist_ok=True)

    report_base = os.path.join(output_dir, orig_stem + "_vmaf")
    out_file_path = f"{report_base}.{output_type}"

    cmd = [
        vmaf_cli_path,
        "-r", orig_image,
        "-d", recon_image,
        "-w", width,
        "-h", height,
        "-p", pixel_format,
        # Flag and value must be separate argv entries.
        "-b", "8",
        "--feature", "psnr",
        "--feature", "psnr_hvs",
        "--feature", "float_ssim",
        "--feature", "float_ms_ssim",
        "--feature", "ciede",
        "--feature", "cambi",
        f"--{output_type}",
        "-o", out_file_path,
    ]

    # Add extra parameters to the VMAF CLI command
    for key, value in kwargs.items():
        cmd += [str(key), str(value)]

    result = subprocess.run(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE
    )

    if save_output_to_file:
        # errors="replace" keeps the logs writable even if the tool emits
        # bytes that are not valid UTF-8.
        with open(f"{report_base}.stdout", "w", encoding="utf-8") as fwriter:
            fwriter.write(result.stdout.decode("utf-8", errors="replace"))
        with open(f"{report_base}.stderr", "w", encoding="utf-8") as fwriter:
            fwriter.write(result.stderr.decode("utf-8", errors="replace"))

    if result.returncode != 0:
        # Do not hand back the path of a report the tool failed to produce.
        return None
    return out_file_path
def compute_vmaf_wrapper(image_pairs: dict, vmaf_cli_path:str, output_type: str, output_dir: str, gen_output: bool):
    """Adapt one image-pair dict to a ``compute_vmaf`` call.

    The pair comes first so that the remaining arguments can be bound once
    (e.g. with ``functools.partial``) and the result mapped over many pairs.

    Args:
        image_pairs: Mapping holding the original image path under ``"orig"``
            and the reconstructed image path under ``"recon"``.
        vmaf_cli_path: Forwarded to ``compute_vmaf`` unchanged.
        output_type: Forwarded to ``compute_vmaf`` unchanged.
        output_dir: Forwarded to ``compute_vmaf`` unchanged.
        gen_output: Forwarded as ``save_output_to_file``.

    Returns:
        Whatever ``compute_vmaf`` returns for this pair.
    """
    original_path = image_pairs["orig"]
    reconstructed_path = image_pairs["recon"]
    forwarded = {
        "vmaf_cli_path": vmaf_cli_path,
        "orig_image": original_path,
        "recon_image": reconstructed_path,
        "output_type": output_type,
        "output_dir": output_dir,
        "save_output_to_file": gen_output,
    }
    return compute_vmaf(**forwarded)
if __name__ == "__main__":
"""
# e.g.
orig_image_path = "../datasets/test/val/munich/munich_000014_000019_leftImg8bit_2048x1024_yuv420p.yuv"
recon_image_path = "../datasets/test/val_QP47/munich_000014_000019_leftImg8bit_2048x1024_yuv420p_bs_decoded.yuv"
file_path = compute_vmaf(orig_image_path, recon_image_path, output_dir=os.path.dirname(recon_image_path))
vmaf_output = VMAFOutput.parse_vmaf_json(file_path)
print(f"version: {vmaf_output.version}, (w, h): ({vmaf_output.sequence_width}, {vmaf_output.sequence_height}), "
f"fps: {vmaf_output.sequence_fps}")
print(f"VMAF: {vmaf_output.frames[0].psnr_y}")
print(f"PSNR: {vmaf_output.frames[0].vmaf}")
# """
parser = argparse.ArgumentParser(description="Runs VMAF CLI tool to compute VMAF and other metrics.",
epilog="e.g. python compute_vmaf.py "
"../vmaf "
"../datasets/cityscapes/train "
"../datasets/cityscapes/train_QP47")
parser.add_argument("vmaf_cli_path", type=str,
help="Path pointing to VMAF CLI.")
parser.add_argument("orig_images_dataset", type=str,
help="Path to original yuv images.")
parser.add_argument("recon_images_dataset", type=str,
help="Path to decoded/reconstructed yuv images. Expecting it to be a single folder containing "
"the decoded images with the name ending with `_bs_decoded.yuv`")
parser.add_argument("-j", "--jobs", type=int, default=multiprocessing.cpu_count(),
help=f"Number of parallel jobs. Defauls to the total number of CPUs.")
parser.add_argument("-o", "--output_dir", type=str, default=None,
help=f"If passed then its used to save all output files, else the same directory of the "
f"recon images is used.")
parser.add_argument("-ot", "--output_type", type=str, default="json",
help="Sets the VMAF output file type. Valid values are `XML`, `JSON`, `CSV` and `SUB`. "
"Defaults to JSON.")
parser.add_argument("-go", "--gen_output", action="store_true",
help="If passed output of VTM tool to stdout and stderr are stored in files.")
args = parser.parse_args()
# Scan original images and generates expected decoded file path
recon_files = []
images = []
for dirpath, dirnames, filenames in os.walk(args.orig_images_dataset): # noqa
for file in filenames:
if os.path.splitext(file)[-1] == ".yuv" and "_bs" not in file:
# Generates the decoded files path
recon_files.append(
os.path.join(
args.recon_images_dataset,
os.path.splitext(os.path.basename(file))[0]+"_bs_decoded.yuv"
)
)
# Generate image pairs as dicts
images.append(
{
"orig": os.path.join(dirpath, file),
"recon": recon_files[-1],
"valid": False
}
)
# Scan decoded images and check for the existing decoded file
for dirpath, dirnames, filenames in os.walk(args.recon_images_dataset): # noqa
for file in filenames:
if os.path.splitext(file)[-1] == ".yuv" and os.path.basename(file).endswith("_bs_decoded.yuv"):
indexes = []
for idx, rec_file in enumerate(recon_files):
if file in rec_file:
indexes.append(idx)
if len(indexes) < 1:
raise ValueError(f"Decoded image ({os.path.join(dirpath, file)}) didn't have a original "
f"correspondence under {args.orig_images_dataset}.")
if len(indexes) > 1:
raise ValueError(f"Decoded image ({os.path.join(dirpath, file)}) have multiple original "
f"correspondences under {args.orig_images_dataset}.")
images[indexes[0]]["valid"] = True
for obj in images:
if not obj["valid"]:
raise ValueError(f"Didn't find a valid image pair for: {obj['orig']}:{obj['recon']}")
pool = multiprocessing.Pool(processes=args.jobs)
compute_vmaf_wrapper_partial = partial(
compute_vmaf_wrapper,
vmaf_cli_path=args.vmaf_cli_path,
output_type=args.output_type,
output_dir=args.output_dir,
gen_output=args.gen_output
)
for _ in tqdm.tqdm(
pool.imap(compute_vmaf_wrapper_partial, images),
total=len(images)
):
pass
print(">> done")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment