-
-
Save vfdev-5/a2e30ed50b5996807c9b09d5d33d8bc2 to your computer and use it in GitHub Desktop.
Torchvision resize uint8 benchmarks
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch.utils.benchmark import Timer, Compare | |
from torchvision.transforms import functional as F_stable | |
from torchvision.transforms.v2 import functional as F_v2 | |
from itertools import product | |
from functools import partial | |
from typing import List, Optional | |
import torch | |
from torch import Tensor | |
from torch.nn.functional import interpolate as torch_interpolate | |
from torchvision.transforms._functional_tensor import _cast_squeeze_in, _cast_squeeze_out, _assert_image_tensor | |
class F_stable_nightly:
    """Local copy of a tensor resize path, benchmarked under the "nightly" tag.

    The image is routed through ``_cast_squeeze_in`` with only float32/float64
    accepted, so uint8 inputs are resized in floating point by
    ``torch.nn.functional.interpolate`` and then converted back by
    ``_cast_squeeze_out``.
    """

    @staticmethod
    def resize(
        img: Tensor,
        size: List[int],
        interpolation: str = "bilinear",
        # TODO: in v0.17, change the default to True. This will be a private
        # function by then, so we don't care about warning here.
        antialias: Optional[bool] = None,
    ) -> Tensor:
        """Resize ``img`` with ``torch.nn.functional.interpolate``.

        Args:
            img: image tensor, validated by ``_assert_image_tensor``.
            size: forwarded to ``interpolate`` (a tuple is converted to a list
                first). An int is forwarded unchanged, so ``interpolate`` uses
                it for every spatial dimension (square output); it is NOT
                interpreted as a "smaller edge" length.
            interpolation: mode name such as ``"bilinear"``, or an enum member
                (e.g. torchvision's ``InterpolationMode``) whose ``.value`` is
                that name.
            antialias: ``None`` is treated as ``False``. Forced to ``False``
                for modes other than bilinear/bicubic.

        Returns:
            The resized tensor after ``_cast_squeeze_out`` (presumably restoring
            the original dtype and rank recorded by ``_cast_squeeze_in``).
        """
        _assert_image_tensor(img)
        # Callers pass InterpolationMode members, while the declared default is
        # a plain string; unwrap ``.value`` only when it exists so both work.
        interpolation = getattr(interpolation, "value", interpolation)
        if isinstance(size, tuple):
            size = list(size)
        if antialias is None:
            antialias = False
        if antialias and interpolation not in ["bilinear", "bicubic"]:
            # We manually set it to False to avoid an error downstream in interpolate()
            # This behaviour is documented: the parameter is irrelevant for modes
            # that are not bilinear or bicubic. We used to raise an error here, but
            # now we don't as True is the default.
            antialias = False

        # Only float dtypes are accepted, so integer images are cast before
        # interpolating (this is what makes the path slower for uint8).
        acceptable_dtypes = [torch.float32, torch.float64]
        img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, acceptable_dtypes)

        # Define align_corners to avoid warnings
        align_corners = False if interpolation in ["bilinear", "bicubic"] else None
        img = torch_interpolate(
            img, size=size, mode=interpolation, align_corners=align_corners, antialias=antialias
        )

        # Bicubic results can fall outside [0, 255]; clamp before converting
        # back to uint8.
        if interpolation == "bicubic" and out_dtype == torch.uint8:
            img = img.clamp(min=0, max=255)
        img = _cast_squeeze_out(img, need_cast=need_cast, need_squeeze=need_squeeze, out_dtype=out_dtype)
        return img
# Toggle for quick runs: debug mode shortens each timing run to 0.1 s, and
# other parts of this script also branch on this flag.
debug = False
min_run_time = 0.1 if debug else 10
def gen_inputs():
    """Yield ``(benchmark_args, kwargs)`` pairs, one per implementation per case.

    ``benchmark_args`` is ``(label, sub_label, num_threads, tag, fn, tensor, size)``,
    matching the positional parameters of ``benchmark``. ``kwargs`` holds the
    resize keyword arguments. Each case's input tensor is yielded three times,
    in this order: torchvision v2 ("v2"), torchvision stable ("stable"), and the
    local ``F_stable_nightly`` copy ("nightly").
    """
    make_arg_int = partial(torch.randint, 0, 256, dtype=torch.uint8)
    shapes = (
        (3, 400, 400),
        (16, 3, 400, 400),
    )
    modes = [
        F_stable.InterpolationMode.NEAREST,
        F_stable.InterpolationMode.BILINEAR,
        # F_stable.InterpolationMode.BICUBIC,
    ]
    # The debug and full configurations were identical, so there is one set.
    # For a wider sweep outside debug mode, consider
    # devices = ("cpu", "cuda") and thread_counts = (1, torch.get_num_threads()).
    makers = (make_arg_int,)
    devices = ("cpu",)
    fns = ["resize"]
    thread_counts = (1,)

    implementations = (
        ("v2", F_v2),
        ("stable", F_stable),
        ("nightly", F_stable_nightly),
    )

    # ``num_threads`` is a distinct name so it does not shadow the
    # ``thread_counts`` tuple being iterated.
    for make, shape, device, fn_name, num_threads, mode in product(
        makers, shapes, devices, fns, thread_counts, modes
    ):
        t1 = make(shape, device=device)
        args = (64,)  # target size, forwarded unchanged to each implementation
        label = f"{fn_name.capitalize()} {device} {t1.dtype} {mode}"
        sub_label = str(tuple(shape))
        for tag, namespace in implementations:
            fn = getattr(namespace, fn_name)
            kwargs = dict(interpolation=mode, antialias=True)
            yield (label, sub_label, num_threads, tag, fn, t1, *args), kwargs
def benchmark(label, sub_label, threads, tag, f, *args, **kwargs):
    """Time ``f(*args, **kwargs)`` and return the resulting measurement.

    When ``debug`` is set and ``f`` is not torchvision stable's function of the
    same name, ``f`` is first checked against that reference with
    ``torch.testing.assert_close`` (atol=1, rtol=0).

    ``label``, ``sub_label``, ``threads`` and ``tag`` only configure the Timer:
    its label, sub-label, thread count and description suffix.
    """
    if debug:
        reference_fn = getattr(F_stable, f.__name__)
        if reference_fn is not f:
            actual = f(*args, **kwargs)
            expected = reference_fn(*args, **kwargs)
            torch.testing.assert_close(expected, actual, atol=1, rtol=0)

    # Expose only the names the timed statement actually uses.
    timer = Timer(
        "f(*args, **kwargs)",
        globals={"f": f, "args": args, "kwargs": kwargs},
        label=label,
        description=f"{f.__name__} {tag}",
        sub_label=sub_label,
        num_threads=threads,
    )
    return timer.blocked_autorange(min_run_time=min_run_time)
# Run every generated case, then print one comparison table per label.
results = []
for case_args, case_kwargs in gen_inputs():
    if debug:
        print(case_args[:4], case_kwargs)
    measurement = benchmark(*case_args, **case_kwargs)
    results.append(measurement)

compare = Compare(results)
compare.trim_significant_figures()
compare.print()
# Sample output (CPU, 1 thread):
# [----------- Resize cpu torch.uint8 InterpolationMode.NEAREST -----------]
#                       |  resize v2  |  resize stable  |  resize nightly
# 1 threads: ---------------------------------------------------------------
#       (3, 400, 400)   |     457     |       461       |       480
#   (16, 3, 400, 400)   |    6870     |      6850       |     10100
# Times are in microseconds (us).
# [---------- Resize cpu torch.uint8 InterpolationMode.BILINEAR -----------]
#                       |  resize v2  |  resize stable  |  resize nightly
# 1 threads: ---------------------------------------------------------------
#       (3, 400, 400)   |     326     |       329       |       844
#   (16, 3, 400, 400)   |    4380     |      4390       |     14800
# Times are in microseconds (us).
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment