Skip to content

Instantly share code, notes, and snippets.

@vfdev-5
Last active May 4, 2023 14:35
Show Gist options
  • Save vfdev-5/a2e30ed50b5996807c9b09d5d33d8bc2 to your computer and use it in GitHub Desktop.
Save vfdev-5/a2e30ed50b5996807c9b09d5d33d8bc2 to your computer and use it in GitHub Desktop.
Torchvision resize uint8 benchmarks
import torch
from torch.utils.benchmark import Timer, Compare
from torchvision.transforms import functional as F_stable
from torchvision.transforms.v2 import functional as F_v2
from itertools import product
from functools import partial
from typing import List, Optional
import torch
from torch import Tensor
from torch.nn.functional import interpolate as torch_interpolate
from torchvision.transforms._functional_tensor import _cast_squeeze_in, _cast_squeeze_out, _assert_image_tensor
class F_stable_nightly:
    """Namespace for the tensor ``resize`` variant benchmarked under the "nightly" tag.

    The function is looked up on the class itself
    (``getattr(F_stable_nightly, "resize")``), so it takes no ``self`` and is
    exposed as a static method.
    """

    @staticmethod
    def resize(
        img: Tensor,
        size: List[int],
        interpolation: str = "bilinear",
        # TODO: in v0.17, change the default to True. This will be a private function
        # by then, so we don't care about warning here.
        antialias: Optional[bool] = None,
    ) -> Tensor:
        """Resize ``img`` to ``size`` with ``torch.nn.functional.interpolate``.

        Args:
            img: Image tensor; validated with ``_assert_image_tensor``.
            size: Target size forwarded to ``interpolate``. A tuple is
                converted to a list first.
            interpolation: Interpolation mode, either as a plain string such as
                ``"bilinear"`` or as an enum-like object whose ``.value`` is
                that string (the benchmark passes ``InterpolationMode`` members).
            antialias: Whether to request antialiasing. ``None`` is treated as
                ``False``; the flag is also forced to ``False`` for modes other
                than ``"bilinear"`` and ``"bicubic"``.

        Returns:
            The resized tensor after ``_cast_squeeze_out`` has been applied
            with the dtype reported by ``_cast_squeeze_in`` (presumably
            restoring the input dtype and shape rank -- confirm against the
            torchvision helpers).
        """
        _assert_image_tensor(img)

        # Accept both enum members and plain strings. The original code called
        # ``.value`` unconditionally, which raised AttributeError for the
        # declared string default.
        interpolation = getattr(interpolation, "value", interpolation)

        if isinstance(size, tuple):
            size = list(size)

        supports_antialias = interpolation in ("bilinear", "bicubic")
        if antialias is None or not supports_antialias:
            # We manually set it to False to avoid an error downstream in interpolate()
            # This behaviour is documented: the parameter is irrelevant for modes
            # that are not bilinear or bicubic. We used to raise an error here, but
            # now we don't as True is the default.
            antialias = False

        acceptable_dtypes = [torch.float32, torch.float64]
        img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, acceptable_dtypes)

        # Define align_corners to avoid warnings
        align_corners = False if supports_antialias else None

        img = torch_interpolate(
            img,
            size=size,
            mode=interpolation,
            align_corners=align_corners,
            antialias=antialias,
        )

        # Keep bicubic results inside the uint8 range before casting back
        # (presumably because bicubic interpolation can overshoot).
        if interpolation == "bicubic" and out_dtype == torch.uint8:
            img = img.clamp(min=0, max=255)

        img = _cast_squeeze_out(img, need_cast=need_cast, need_squeeze=need_squeeze, out_dtype=out_dtype)
        return img
# Set to True for a quick smoke run: shorter timing windows and extra checks
# wherever the rest of the script consults ``debug``.
debug = False

# Minimum measurement time handed to ``blocked_autorange`` for each case.
min_run_time = 0.1 if debug else 10
def gen_inputs():
    """Yield one ``(args, kwargs)`` pair per benchmark case.

    ``args`` is ``(label, sub_label, threads, tag, fn, tensor, 64)`` and
    ``kwargs`` carries the interpolation mode plus ``antialias=True``. For every
    combination of maker, shape, device, function name, thread count and mode,
    three cases are produced in order: "v2", "stable" and "nightly".
    """
    make_uint8 = partial(torch.randint, 0, 256, dtype=torch.uint8)

    shapes = ((3, 400, 400), (16, 3, 400, 400))
    # BICUBIC is intentionally left out of the sweep for now.
    modes = [
        F_stable.InterpolationMode.NEAREST,
        F_stable.InterpolationMode.BILINEAR,
    ]

    # The debug and full configurations currently coincide. A wider full run
    # would add "cuda" to the devices and torch.get_num_threads() to the
    # thread counts.
    makers = (make_uint8,)
    devices = ("cpu",)
    fn_names = ["resize"]
    thread_counts = (1,)

    # (tag, namespace) pairs, in the order the cases are emitted.
    implementations = (
        ("v2", F_v2),
        ("stable", F_stable),
        ("nightly", F_stable_nightly),
    )

    cases = product(makers, shapes, devices, fn_names, thread_counts, modes)
    for make, shape, device, fn_name, num_threads, mode in cases:
        tensor = make(shape, device=device)
        label = f"{fn_name.capitalize()} {device} {tensor.dtype} {mode}"
        sub_label = str(tuple(shape))
        kwargs = dict(interpolation=mode, antialias=True)

        for tag, namespace in implementations:
            fn = getattr(namespace, fn_name)
            yield (label, sub_label, num_threads, tag, fn, tensor, 64), kwargs
def benchmark(label, sub_label, threads, tag, f, *args, **kwargs):
    """Time ``f(*args, **kwargs)`` and return the ``blocked_autorange`` result.

    Args:
        label: Table title passed to ``Timer``.
        sub_label: Row label passed to ``Timer``.
        threads: Value passed to ``Timer`` as ``num_threads``.
        tag: Suffix appended to ``f.__name__`` to form the column description.
        f: Callable under test.
        *args, **kwargs: Arguments forwarded to ``f``.

    In debug mode the output of ``f`` is first compared against the function
    of the same name in ``F_stable`` (absolute tolerance 1, no relative
    tolerance), unless ``f`` is that very function.
    """
    if debug:
        reference = getattr(F_stable, f.__name__)
        if f is not reference:
            actual = f(*args, **kwargs)
            expected = reference(*args, **kwargs)
            torch.testing.assert_close(expected, actual, atol=1, rtol=0)

    # Only these three names are needed by the timed statement.
    timer = Timer(
        "f(*args, **kwargs)",
        globals={"f": f, "args": args, "kwargs": kwargs},
        label=label,
        sub_label=sub_label,
        description=f"{f.__name__} {tag}",
        num_threads=threads,
    )
    return timer.blocked_autorange(min_run_time=min_run_time)
# Run every generated case and collect the measurements.
results = []
for args, kwargs in gen_inputs():
    if debug:
        # Show label, sub-label, thread count and tag, plus the call kwargs.
        print(args[:4], kwargs)
    measurement = benchmark(*args, **kwargs)
    results.append(measurement)

# Render the comparison table.
compare = Compare(results)
compare.trim_significant_figures()
compare.print()
# [----------- Resize cpu torch.uint8 InterpolationMode.NEAREST -----------]
# | resize v2 | resize stable | resize nightly
# 1 threads: ---------------------------------------------------------------
# (3, 400, 400) | 457 | 461 | 480
# (16, 3, 400, 400) | 6870 | 6850 | 10100
# Times are in microseconds (us).
# [---------- Resize cpu torch.uint8 InterpolationMode.BILINEAR -----------]
# | resize v2 | resize stable | resize nightly
# 1 threads: ---------------------------------------------------------------
# (3, 400, 400) | 326 | 329 | 844
# (16, 3, 400, 400) | 4380 | 4390 | 14800
# Times are in microseconds (us).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment