Created
December 2, 2021 17:50
-
-
Save pinxau1000/f3f854199c338b679ea0c13923513607 to your computer and use it in GitHub Desktop.
Python Script for VMAF Computation Using the VMAF CLI and Multiprocessing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import multiprocessing | |
import os | |
import re | |
import subprocess | |
from functools import partial | |
import tqdm | |
def compute_vmaf(vmaf_cli_path: str, orig_image: str, recon_image: str, output_type: str = "json",
                 output_dir: str = None, save_output_to_file: bool = False, bitdepth: int = 8, **kwargs):
    """Run the VMAF CLI on one original/reconstructed raw YUV image pair.

    The frame size is parsed from the first ``<width>x<height>`` token in the
    original image's file name. The pixel format is taken from the last
    ``_``-separated token of that name with ``yuv`` and ``p`` stripped
    (e.g. ``..._2048x1024_yuv420p.yuv`` -> ``-w 2048 -h 1024 -p 420``).

    Args:
        vmaf_cli_path: Path to the VMAF command line executable.
        orig_image: Path to the original (reference) raw YUV file.
        recon_image: Path to the decoded/reconstructed (distorted) raw YUV file.
        output_type: Report format, one of ``json``, ``xml``, ``csv`` or ``sub``
            (case-insensitive). Any other value falls back to ``json``.
        output_dir: Directory for the report and the optional stdout/stderr
            logs. Defaults to the directory containing ``recon_image``.
        save_output_to_file: If True, the CLI's stdout and stderr are written to
            ``<orig name>_vmaf.stdout`` and ``<orig name>_vmaf.stderr`` in
            ``output_dir``.
        bitdepth: Bit depth passed to the CLI with ``-b``.
        **kwargs: Extra CLI arguments. Each ``key: value`` pair is appended as
            the two arguments ``str(key)`` and ``str(value)``.

    Returns:
        The path of the report file requested from the CLI, or None when no
        resolution could be parsed from ``orig_image``'s file name. The CLI's
        exit status is not checked, so the report may be missing if it failed.
    """
    orig_basename = os.path.basename(orig_image)
    orig_stem = os.path.splitext(orig_basename)[0]

    match = re.search(r"(\d+)x(\d+)", orig_basename)
    if not match:
        return None
    width, height = match.groups()

    pixel_format = orig_stem.split("_")[-1].replace("yuv", "").replace("p", "")

    # Accept the upper-case spellings advertised by the CLI help (XML, JSON, ...).
    output_type = output_type.lower() if isinstance(output_type, str) else ""
    if output_type not in ("json", "xml", "csv", "sub"):
        output_type = "json"  # Defaults to json output file.

    if not output_dir:
        output_dir = os.path.dirname(recon_image)
    # dirname() is "" for a bare file name; os.makedirs("") would raise.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    out_file_path = os.path.join(output_dir, orig_stem + f"_vmaf.{output_type}")

    cmd = [
        vmaf_cli_path,
        "-r", orig_image,
        "-d", recon_image,
        "-w", width,
        "-h", height,
        "-p", pixel_format,
        "-b", str(bitdepth),  # option and value must be separate argv entries
        "--feature", "psnr",
        "--feature", "psnr_hvs",
        "--feature", "float_ssim",
        "--feature", "float_ms_ssim",
        "--feature", "ciede",
        "--feature", "cambi",
        f"--{output_type}",
        "-o", out_file_path,
    ]
    # Append extra parameters to the VMAF CLI command.
    for key, value in kwargs.items():
        cmd += [str(key), str(value)]

    result = subprocess.run(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE
    )

    if save_output_to_file:
        log_prefix = os.path.join(output_dir, orig_stem + "_vmaf")
        for suffix, data in ((".stdout", result.stdout), (".stderr", result.stderr)):
            with open(log_prefix + suffix, "w", encoding="utf-8") as fwriter:
                # The CLI output is not guaranteed to be valid UTF-8.
                fwriter.write(data.decode("utf-8", errors="replace"))

    return out_file_path
def compute_vmaf_wrapper(image_pairs: dict, vmaf_cli_path: str, output_type: str, output_dir: str, gen_output: bool):
    """Run ``compute_vmaf`` for one image-pair dict.

    The pair arrives as a single positional argument, so this function can be
    mapped over a list of pairs (e.g. with ``Pool.imap`` after binding the
    remaining arguments with ``functools.partial``).

    Args:
        image_pairs: Mapping holding the original image path under ``"orig"``
            and the reconstructed image path under ``"recon"``.
        vmaf_cli_path: Path to the VMAF CLI executable.
        output_type: Report format forwarded to ``compute_vmaf``.
        output_dir: Output directory forwarded to ``compute_vmaf``.
        gen_output: Forwarded as ``save_output_to_file``.

    Returns:
        Whatever ``compute_vmaf`` returns for this pair.
    """
    call_kwargs = {
        "vmaf_cli_path": vmaf_cli_path,
        "orig_image": image_pairs["orig"],
        "recon_image": image_pairs["recon"],
        "output_type": output_type,
        "output_dir": output_dir,
        "save_output_to_file": gen_output,
    }
    return compute_vmaf(**call_kwargs)
if __name__ == "__main__":
    # Command line entry point: pair every original `.yuv` image with its
    # `<name>_bs_decoded.yuv` counterpart and run the VMAF CLI on all pairs in
    # parallel worker processes.
    #
    # Programmatic use of a single pair (paths are illustrative):
    #   compute_vmaf(
    #       "../vmaf",
    #       "../datasets/test/val/munich/munich_000014_000019_leftImg8bit_2048x1024_yuv420p.yuv",
    #       "../datasets/test/val_QP47/munich_000014_000019_leftImg8bit_2048x1024_yuv420p_bs_decoded.yuv",
    #   )
    parser = argparse.ArgumentParser(description="Runs VMAF CLI tool to compute VMAF and other metrics.",
                                     epilog="e.g. python compute_vmaf.py "
                                            "../vmaf "
                                            "../datasets/cityscapes/train "
                                            "../datasets/cityscapes/train_QP47")
    parser.add_argument("vmaf_cli_path", type=str,
                        help="Path pointing to VMAF CLI.")
    parser.add_argument("orig_images_dataset", type=str,
                        help="Path to original yuv images.")
    parser.add_argument("recon_images_dataset", type=str,
                        help="Path to decoded/reconstructed yuv images. It is searched recursively for decoded "
                             "images whose name ends with `_bs_decoded.yuv`")
    parser.add_argument("-j", "--jobs", type=int, default=multiprocessing.cpu_count(),
                        help="Number of parallel jobs. Defaults to the total number of CPUs.")
    parser.add_argument("-o", "--output_dir", type=str, default=None,
                        help="If passed, it is used to save all output files, else the same directory of the "
                             "recon images is used.")
    # type=str.lower makes the choice case-insensitive (XML, Json, ...).
    parser.add_argument("-ot", "--output_type", type=str.lower, default="json",
                        choices=("json", "xml", "csv", "sub"),
                        help="Sets the VMAF output file type (case-insensitive). Valid values are `XML`, `JSON`, "
                             "`CSV` and `SUB`. Defaults to JSON.")
    parser.add_argument("-go", "--gen_output", action="store_true",
                        help="If passed, the stdout and stderr of the VMAF CLI are stored in files.")
    args = parser.parse_args()

    # Scan original images and generate the expected decoded file for each one.
    images = []
    expected_recon_names = {}  # decoded file name -> indexes into `images`
    for dirpath, _, filenames in os.walk(args.orig_images_dataset):
        for file in filenames:
            stem, ext = os.path.splitext(file)
            if ext != ".yuv" or "_bs" in file:
                continue
            recon_name = stem + "_bs_decoded.yuv"
            expected_recon_names.setdefault(recon_name, []).append(len(images))
            images.append(
                {
                    "orig": os.path.join(dirpath, file),
                    "recon": os.path.join(args.recon_images_dataset, recon_name),
                    "valid": False
                }
            )

    # Scan decoded images and match each one by its exact file name. A substring
    # test would also match e.g. `a_bs_decoded.yuv` against `xa_bs_decoded.yuv`.
    for dirpath, _, filenames in os.walk(args.recon_images_dataset):
        for file in filenames:
            if not file.endswith("_bs_decoded.yuv"):
                continue
            indexes = expected_recon_names.get(file, [])
            if len(indexes) < 1:
                raise ValueError(f"Decoded image ({os.path.join(dirpath, file)}) didn't have a original "
                                 f"correspondence under {args.orig_images_dataset}.")
            if len(indexes) > 1:
                raise ValueError(f"Decoded image ({os.path.join(dirpath, file)}) have multiple original "
                                 f"correspondences under {args.orig_images_dataset}.")
            # Use the real location, which may be a sub-folder of recon_images_dataset.
            images[indexes[0]]["recon"] = os.path.join(dirpath, file)
            images[indexes[0]]["valid"] = True

    for obj in images:
        if not obj["valid"]:
            raise ValueError(f"Didn't find a valid image pair for: {obj['orig']}:{obj['recon']}")

    compute_vmaf_wrapper_partial = partial(
        compute_vmaf_wrapper,
        vmaf_cli_path=args.vmaf_cli_path,
        output_type=args.output_type,
        output_dir=args.output_dir,
        gen_output=args.gen_output
    )

    # The context manager releases the worker processes once all results are consumed.
    with multiprocessing.Pool(processes=args.jobs) as pool:
        results = list(tqdm.tqdm(
            pool.imap(compute_vmaf_wrapper_partial, images),
            total=len(images)
        ))

    # imap preserves input order, so results line up with `images`.
    # compute_vmaf returns None when it could not parse a resolution from the file name.
    for image_pair, out_file in zip(images, results):
        if out_file is None:
            print(f"Skipped (no WIDTHxHEIGHT in file name): {image_pair['orig']}")

    print(">> done")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment