@pmeier
Last active December 14, 2022 11:25

Benchmark script and results for torchvision.transforms.functional v1 vs v2

Run benchmark.py to reproduce the results.

To add a new benchmark, add a BenchmarkConfig to BENCHMARK_CONFIGS in configs.py.
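
A config's `kwargs` field holds a list of keyword-argument dicts, and each dict is benchmarked as its own case. As a minimal sketch of how such a list can be built from a parameter grid, the snippet below copies the `product_dict` helper from configs.py. The grid values are hypothetical, with plain strings standing in for the `InterpolationMode` members so that the snippet runs without torchvision.

```python
import itertools


def product_dict(**kwargs):
    # Same helper as in configs.py: one dict per combination of the listed values.
    return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())]


# Hypothetical parameter grid for a new benchmark; the values are illustrative only.
grid = product_dict(size=[[256, 256]], interpolation=["nearest", "bilinear"])
for kwargs in grid:
    print(kwargs)
```

Passing `grid` as `kwargs=` to a `BenchmarkConfig` would then yield one set of measurements per entry.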

[----------- adjust_brightness @ torchvision==0.15.0a0+b1f6c9e -----------]
| v1 | v2
1 threads: ----------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 540 (+- 2) | 286 (+- 4)
(3, 400, 400) / uint8 / cuda | 76 (+- 0) | 30 (+- 0)
(3, 400, 400) / PIL | 600 (+- 57) | 598 (+- 0)
(3, 400, 400) / float32 / cpu | 185 (+- 5) | 66 (+- 2)
(3, 400, 400) / float32 / cuda | 76 (+- 0) | 26 (+- 0)
(16, 3, 400, 400) / uint8 / cpu | 18817 (+- 66) | 8243 (+- 97)
(16, 3, 400, 400) / uint8 / cuda | 1108 (+- 0) | 561 (+- 0)
(16, 3, 400, 400) / float32 / cpu | 13168 (+-126) | 3004 (+- 24)
(16, 3, 400, 400) / float32 / cuda | 1181 (+- 0) | 478 (+- 0)
6 threads: ----------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 149 (+- 0) | 76 (+- 1)
(3, 400, 400) / float32 / cpu | 72 (+- 0) | 26 (+- 0)
(16, 3, 400, 400) / uint8 / cpu | 8788 (+-102) | 1767 (+- 10)
(16, 3, 400, 400) / float32 / cpu | 9587 (+- 94) | 236 (+- 1)
Times are in microseconds (us).
Aggregated performance change of v2 vs. v1: -59.9% (improvement)
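
Each cell in these tables reads `median (+-spread)`, where report.py computes the spread as the standard deviation of the individual timings. The following is a minimal standard-library sketch of that cell format, leaving out the significant-figure trimming. `format_cell` and the sample timings are hypothetical and not part of the gist.

```python
from statistics import median, stdev


def format_cell(times_s, time_scale=1e-6):
    """Render 'median (+-spread)' in units of time_scale (default: microseconds)."""
    # A single sample has no spread; otherwise use the sample standard deviation.
    spread = 0.0 if len(times_s) == 1 else stdev(times_s)
    return f"{median(times_s) / time_scale:>3.0f} (+-{spread / time_scale:>3.0f})"


# Hypothetical per-run timings in seconds, not taken from the results in this gist.
cell = format_cell([285e-6, 286e-6, 291e-6])
print(cell)
```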
[------------ adjust_contrast @ torchvision==0.15.0a0+b1f6c9e -------------]
| v1 | v2
1 threads: -----------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 569 (+- 5) | 428 (+- 3)
(3, 400, 400) / uint8 / cuda | 103 (+- 1) | 74 (+- 0)
(3, 400, 400) / PIL | 789 (+-187) | 1166 (+- 1)
(3, 400, 400) / float32 / cpu | 241 (+- 2) | 165 (+- 2)
(3, 400, 400) / float32 / cuda | 81 (+- 0) | 60 (+- 0)
(16, 3, 400, 400) / uint8 / cpu | 19346 (+-743) | 15291 (+-214)
(16, 3, 400, 400) / uint8 / cuda | 1422 (+- 1) | 1218 (+- 0)
(16, 3, 400, 400) / float32 / cpu | 11158 (+-282) | 8467 (+-131)
(16, 3, 400, 400) / float32 / cuda | 1269 (+- 1) | 1103 (+- 1)
6 threads: -----------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 258 (+- 0) | 187 (+- 0)
(3, 400, 400) / float32 / cpu | 120 (+- 2) | 84 (+- 0)
(16, 3, 400, 400) / uint8 / cpu | 7483 (+-1566) | 5980 (+-244)
(16, 3, 400, 400) / float32 / cpu | 12332 (+-558) | 1819 (+- 10)
Times are in microseconds (us).
Aggregated performance change of v2 vs. v1: -15.3% (improvement)
[-------------- adjust_gamma @ torchvision==0.15.0a0+b1f6c9e --------------]
| v1 | v2
1 threads: -----------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 2188 (+- 9) | 2126 (+- 10)
(3, 400, 400) / uint8 / cuda | 95 (+- 0) | 70 (+- 0)
(3, 400, 400) / PIL | 258 (+- 13) | 222 (+- 0)
(3, 400, 400) / float32 / cpu | 1822 (+- 4) | 1787 (+- 5)
(3, 400, 400) / float32 / cuda | 44 (+- 0) | 35 (+- 0)
(16, 3, 400, 400) / uint8 / cpu | 42383 (+-157) | 38352 (+- 99)
(16, 3, 400, 400) / uint8 / cuda | 1521 (+- 0) | 1519 (+- 0)
(16, 3, 400, 400) / float32 / cpu | 31874 (+-119) | 29534 (+- 86)
(16, 3, 400, 400) / float32 / cuda | 718 (+- 0) | 718 (+- 0)
6 threads: -----------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 441 (+- 5) | 406 (+- 1)
(3, 400, 400) / float32 / cpu | 341 (+- 1) | 323 (+- 1)
(16, 3, 400, 400) / uint8 / cpu | 11071 (+-100) | 7148 (+- 61)
(16, 3, 400, 400) / float32 / cpu | 5972 (+- 15) | 4942 (+- 8)
Times are in microseconds (us).
Aggregated performance change of v2 vs. v1: -13.7% (improvement)
[----------------- adjust_hue @ torchvision==0.15.0a0+b1f6c9e -----------------]
| v1 | v2
1 threads: ---------------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 12618 (+-570) | 7870 (+- 49)
(3, 400, 400) / uint8 / cuda | 2068 (+- 28) | 742 (+- 3)
(3, 400, 400) / PIL | 5794 (+-169) | 5727 (+-155)
(3, 400, 400) / float32 / cpu | 12140 (+-938) | 8190 (+- 49)
(3, 400, 400) / float32 / cuda | 1983 (+- 4) | 695 (+- 3)
(16, 3, 400, 400) / uint8 / cpu | 364034 (+-4545) | 187053 (+-2614)
(16, 3, 400, 400) / uint8 / cuda | 32762 (+- 42) | 12431 (+- 15)
(16, 3, 400, 400) / float32 / cpu | 313904 (+-2877) | 188087 (+-1383)
(16, 3, 400, 400) / float32 / cuda | 31958 (+- 32) | 11628 (+- 11)
6 threads: ---------------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 4414 (+- 46) | 4237 (+- 12)
(3, 400, 400) / float32 / cpu | 4318 (+- 7) | 3949 (+- 11)
(16, 3, 400, 400) / uint8 / cpu | 153182 (+-3217) | 125227 (+-1712)
(16, 3, 400, 400) / float32 / cpu | 143030 (+-2415) | 117018 (+-1404)
Times are in microseconds (us).
Aggregated performance change of v2 vs. v1: -19.7% (improvement)
[----------- adjust_saturation @ torchvision==0.15.0a0+b1f6c9e ------------]
| v1 | v2
1 threads: -----------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 581 (+- 9) | 430 (+- 0)
(3, 400, 400) / uint8 / cuda | 83 (+- 0) | 69 (+- 0)
(3, 400, 400) / PIL | 678 (+- 1) | 663 (+- 0)
(3, 400, 400) / float32 / cpu | 250 (+- 0) | 153 (+- 0)
(3, 400, 400) / float32 / cuda | 76 (+- 0) | 60 (+- 0)
(16, 3, 400, 400) / uint8 / cpu | 25039 (+-283) | 15582 (+-118)
(16, 3, 400, 400) / uint8 / cuda | 1442 (+- 0) | 1193 (+- 0)
(16, 3, 400, 400) / float32 / cpu | 12902 (+- 88) | 6944 (+- 60)
(16, 3, 400, 400) / float32 / cuda | 1381 (+- 0) | 1078 (+- 6)
6 threads: -----------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 286 (+- 1) | 228 (+- 0)
(3, 400, 400) / float32 / cpu | 187 (+- 2) | 144 (+- 0)
(16, 3, 400, 400) / uint8 / cpu | 15721 (+- 73) | 5827 (+- 36)
(16, 3, 400, 400) / float32 / cpu | 8241 (+- 46) | 1627 (+- 10)
Times are in microseconds (us).
Aggregated performance change of v2 vs. v1: -21.8% (improvement)
[------------ adjust_sharpness @ torchvision==0.15.0a0+b1f6c9e ------------]
| v1 | v2
1 threads: -----------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 1445 (+- 9) | 1112 (+- 7)
(3, 400, 400) / uint8 / cuda | 256 (+- 1) | 208 (+- 0)
(3, 400, 400) / PIL | 2808 (+- 72) | 2687 (+- 4)
(3, 400, 400) / float32 / cpu | 1055 (+- 20) | 883 (+- 1)
(3, 400, 400) / float32 / cuda | 238 (+- 0) | 193 (+- 1)
(16, 3, 400, 400) / uint8 / cpu | 53049 (+-463) | 38889 (+-419)
(16, 3, 400, 400) / uint8 / cuda | 3336 (+- 0) | 3025 (+- 2)
(16, 3, 400, 400) / float32 / cpu | 46336 (+-335) | 35858 (+-312)
(16, 3, 400, 400) / float32 / cuda | 3077 (+- 3) | 2705 (+- 2)
6 threads: -----------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 1222 (+- 7) | 1100 (+- 3)
(3, 400, 400) / float32 / cpu | 1214 (+- 13) | 1066 (+- 5)
(16, 3, 400, 400) / uint8 / cpu | 38418 (+- 94) | 29784 (+- 79)
(16, 3, 400, 400) / float32 / cpu | 42614 (+- 83) | 29198 (+- 83)
Times are in microseconds (us).
Aggregated performance change of v2 vs. v1: -12.6% (improvement)
[------------ affine BILINEAR @ torchvision==0.15.0a0+b1f6c9e -------------]
| v1 | v2
1 threads: -----------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 3186 (+- 5) | 3170 (+- 1)
(3, 400, 400) / uint8 / cuda | 376 (+- 0) | 382 (+- 1)
(3, 400, 400) / PIL | 2359 (+- 5) | 2547 (+- 4)
(3, 400, 400) / float32 / cpu | 2916 (+- 2) | 3000 (+- 6)
(3, 400, 400) / float32 / cuda | 330 (+- 2) | 326 (+- 2)
(16, 3, 400, 400) / uint8 / cpu | 36837 (+- 76) | 35710 (+- 29)
(16, 3, 400, 400) / uint8 / cuda | 1627 (+- 2) | 1625 (+- 3)
(16, 3, 400, 400) / float32 / cpu | 29615 (+-222) | 29641 (+- 24)
(16, 3, 400, 400) / float32 / cuda | 1051 (+- 3) | 1051 (+- 3)
6 threads: -----------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 3259 (+- 2) | 2586 (+- 15)
(3, 400, 400) / float32 / cpu | 2862 (+- 9) | 2854 (+- 66)
(16, 3, 400, 400) / uint8 / cpu | 9782 (+- 68) | 8062 (+-265)
(16, 3, 400, 400) / float32 / cpu | 6035 (+- 10) | 5996 (+- 10)
Times are in microseconds (us).
Aggregated performance change of v2 vs. v1: +0.2% (slowdown)
[------------- affine NEAREST @ torchvision==0.15.0a0+b1f6c9e -------------]
| v1 | v2
1 threads: -----------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 1936 (+- 4) | 2038 (+- 5)
(3, 400, 400) / uint8 / cuda | 353 (+- 1) | 355 (+- 1)
(3, 400, 400) / PIL | 229 (+- 0) | 228 (+- 1)
(3, 400, 400) / float32 / cpu | 1667 (+- 1) | 1668 (+- 1)
(3, 400, 400) / float32 / cuda | 309 (+- 1) | 304 (+- 1)
(16, 3, 400, 400) / uint8 / cpu | 16060 (+-106) | 15744 (+- 62)
(16, 3, 400, 400) / uint8 / cuda | 1315 (+- 3) | 1313 (+- 1)
(16, 3, 400, 400) / float32 / cpu | 9553 (+-177) | 9499 (+- 36)
(16, 3, 400, 400) / float32 / cuda | 745 (+- 3) | 737 (+- 6)
6 threads: -----------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 2079 (+- 2) | 1343 (+- 2)
(3, 400, 400) / float32 / cpu | 1667 (+- 6) | 1667 (+- 3)
(16, 3, 400, 400) / uint8 / cpu | 6118 (+- 45) | 4463 (+- 58)
(16, 3, 400, 400) / float32 / cpu | 2187 (+- 23) | 2200 (+- 9)
Times are in microseconds (us).
Aggregated performance change of v2 vs. v1: -0.5% (improvement)
[-------------- autocontrast @ torchvision==0.15.0a0+b1f6c9e -------------]
| v1 | v2
1 threads: ----------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 600 (+- 1) | 380 (+- 3)
(3, 400, 400) / uint8 / cuda | 120 (+- 1) | 102 (+- 0)
(3, 400, 400) / PIL | 411 (+- 1) | 412 (+-568)
(3, 400, 400) / float32 / cpu | 165 (+- 0) | 173 (+- 2)
(3, 400, 400) / float32 / cuda | 104 (+- 1) | 87 (+- 1)
(16, 3, 400, 400) / uint8 / cpu | 15015 (+- 34) | 7420 (+-201)
(16, 3, 400, 400) / uint8 / cuda | 1020 (+- 1) | 1322 (+- 0)
(16, 3, 400, 400) / float32 / cpu | 10094 (+-137) | 6947 (+- 72)
(16, 3, 400, 400) / float32 / cuda | 1001 (+- 1) | 993 (+- 2)
6 threads: ----------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 211 (+- 3) | 231 (+- 0)
(3, 400, 400) / float32 / cpu | 89 (+- 0) | 74 (+- 0)
(16, 3, 400, 400) / uint8 / cpu | 2970 (+- 20) | 1057 (+- 2)
(16, 3, 400, 400) / float32 / cpu | 4540 (+- 18) | 702 (+- 7)
Times are in microseconds (us).
Aggregated performance change of v2 vs. v1: -2.2% (improvement)
# benchmark.py
import contextlib
import gc
import itertools
import pathlib
import pickle
import traceback

import tqdm
import torch
import torchvision
from torch.utils import benchmark as torch_benchmark
from torchvision.transforms.functional import to_pil_image

from configs import BENCHMARK_CONFIGS, product_dict
from report import report, Setup


def main(configs, *, root, min_run_time=10, silently_fail_v1=True, log_to_file=True):
    root = pathlib.Path(root)
    root.mkdir(parents=True, exist_ok=True)

    with tqdm.tqdm(total=len(configs)) as progress_bar:
        for config in configs:
            progress_bar.desc = config.label
            progress_bar.display()

            # A failing config is reported, but does not abort the remaining benchmarks
            with suppress_with_traceback(Exception, msg=f"Something failed for {config} with the following traceback"):
                measurements = benchmark(config, min_run_time=min_run_time, silently_fail_v1=silently_fail_v1)
                normalized_title = measurements[0].task_spec.label.replace(" ", "_")
                save(measurements, file_path=root / f"{normalized_title}.pkl")
                report(
                    measurements, file_path=root / f"{normalized_title}.log" if log_to_file else None, aggregate=True
                )

            progress_bar.update()


def benchmark(config, *, min_run_time, silently_fail_v1):
    title = f"{config.label} @ torchvision=={torchvision.__version__}"

    measurements = []
    for (version, fn), dct in tqdm.tqdm(
        list(
            itertools.product(
                [("v1", config.v1_fn), ("v2", config.v2_fn)],
                product_dict(
                    shape=config.shapes,
                    dtype=config.dtypes,
                    device=config.devices,
                    num_threads=config.cpu_threads,
                    indexed_kwargs=enumerate(config.kwargs),
                ),
            )
        )
    ):
        # The same dict object is shared between the v1 and v2 runs, so work on a copy
        dct = dct.copy()
        idx, kwargs = dct.pop("indexed_kwargs")
        setup = Setup(**dct, kwargs_idx=idx)

        # Multiple threads are only benchmarked on CPU
        if setup.num_threads > 1 and setup.device != "cpu":
            continue

        input, row_label = make_input(setup.shape, dtype=setup.dtype, device=setup.device)
        if input is None:
            continue

        # v1 has some limitations that we don't want to abort the whole benchmark over
        with suppress_with_traceback(
            Exception, msg="Something failed the v1 benchmark with the following traceback", silent=silently_fail_v1
        ) if version == "v1" else contextlib.nullcontext():
            measurement = measure(
                fn,
                input,
                kwargs,
                num_threads=setup.num_threads,
                title=title,
                col_label=version,
                row_label=row_label,
                min_run_time=min_run_time,
            )
            measurement.setup = setup
            measurements.append(measurement)

    return measurements


def measure(fn, input, kwargs, *, num_threads, title, col_label, row_label, min_run_time):
    timer = torch_benchmark.Timer(
        stmt="fn(input, **kwargs)",
        globals=dict(fn=fn, input=input, kwargs=kwargs),
        num_threads=num_threads,
        label=title,
        sub_label=row_label,
        description=col_label,
    )

    # Make it less likely that GC runs during benchmark
    gc.collect()

    return timer.blocked_autorange(min_run_time=min_run_time)


def save(measurements, *, file_path):
    with open(file_path, "wb") as fh:
        pickle.dump(measurements, fh)


@contextlib.contextmanager
def suppress_with_traceback(*exceptions, msg="Something failed with the following traceback", silent=False):
    try:
        yield
    except exceptions as exc:
        if silent:
            return

        print(f"{msg}:\n")
        traceback.print_exception(type(exc), exc, exc.__traceback__)
        print("\nContinuing as if nothing happened", flush=True)


def make_input(shape, *, dtype, device):
    # PIL inputs are only benchmarked for single uint8 images
    if device == "pil" and (len(shape) > 3 or dtype is not torch.uint8):
        return None, ""

    torch.manual_seed(0)
    input = torch.testing.make_tensor(
        shape,
        dtype=dtype,
        device=device if device != "pil" else "cpu",
        low=0,
        high=1.0 if dtype.is_floating_point else torch.iinfo(dtype).max + 1,
    )

    if device == "pil":
        input = to_pil_image(input)
        row_label = f"{shape!s} / PIL"
    else:
        row_label = f"{shape!s} / {str(dtype).replace('torch.', '')} / {device}"

    return input, row_label


if __name__ == "__main__":
    main(BENCHMARK_CONFIGS, root=pathlib.Path.cwd() / "transforms_v1_vs_v2")
[------------ center_crop @ torchvision==0.15.0a0+b1f6c9e -------------]
| v1 | v2
1 threads: -------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 10 (+- 1) | 5 (+- 0)
(3, 400, 400) / uint8 / cuda | 9 (+- 1) | 4 (+- 0)
(3, 400, 400) / PIL | 15 (+- 0) | 11 (+- 1)
(3, 400, 400) / float32 / cpu | 11 (+- 1) | 5 (+- 0)
(3, 400, 400) / float32 / cuda | 9 (+- 0) | 4 (+- 0)
(16, 3, 400, 400) / uint8 / cpu | 9 (+- 1) | 5 (+- 0)
(16, 3, 400, 400) / uint8 / cuda | 11 (+- 1) | 4 (+- 0)
(16, 3, 400, 400) / float32 / cpu | 11 (+- 0) | 5 (+- 0)
(16, 3, 400, 400) / float32 / cuda | 11 (+- 0) | 5 (+- 0)
6 threads: -------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 9 (+- 0) | 4 (+- 0)
(3, 400, 400) / float32 / cpu | 11 (+- 1) | 5 (+- 0)
(16, 3, 400, 400) / uint8 / cpu | 11 (+- 1) | 5 (+- 1)
(16, 3, 400, 400) / float32 / cpu | 11 (+- 1) | 5 (+- 0)
Times are in microseconds (us).
Aggregated performance change of v2 vs. v1: -52.3% (improvement)
# configs.py
import dataclasses
import itertools
from functools import partial
from typing import *

import torch
from torchvision.prototype.datapoints import ColorSpace
from torchvision.prototype.transforms import functional as F_v2, InterpolationMode
from torchvision.transforms import functional as F_v1

SPATIAL_SIZE = (400, 400)
CPU_AND_MAYBE_CUDA = ["cpu", *(["cuda"] if torch.cuda.is_available() else [])]


@dataclasses.dataclass
class BenchmarkConfig:
    label: str
    v1_fn: Callable[[torch.Tensor], Any]
    v2_fn: Callable[[torch.Tensor], Any]
    shapes: List[Tuple[int, ...]] = dataclasses.field(
        default_factory=lambda: [(3, *SPATIAL_SIZE), (16, 3, *SPATIAL_SIZE)]
    )
    dtypes: List[torch.dtype] = dataclasses.field(default_factory=lambda: [torch.uint8, torch.float32])
    devices: List[str] = dataclasses.field(default_factory=lambda: [*CPU_AND_MAYBE_CUDA, "pil"])
    cpu_threads: List[int] = dataclasses.field(default_factory=lambda: [1, 6])
    kwargs: List[Dict[str, Any]] = dataclasses.field(default_factory=lambda: [dict()])


def product_dict(**kwargs):
    return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())]


torch.manual_seed(0)
ELASTIC_DISPLACEMENT = torch.rand(1, *SPATIAL_SIZE, 2)

BENCHMARK_CONFIGS = [
    BenchmarkConfig(
        "adjust_contrast", v1_fn=F_v1.adjust_contrast, v2_fn=F_v2.adjust_contrast, kwargs=[dict(contrast_factor=0.5)]
    ),
    BenchmarkConfig("adjust_hue", v1_fn=F_v1.adjust_hue, v2_fn=F_v2.adjust_hue, kwargs=[dict(hue_factor=0.3)]),
    BenchmarkConfig(
        "adjust_saturation",
        v1_fn=F_v1.adjust_saturation,
        v2_fn=F_v2.adjust_saturation,
        kwargs=[dict(saturation_factor=0.5)],
    ),
    BenchmarkConfig(
        "adjust_sharpness",
        v1_fn=F_v1.adjust_sharpness,
        v2_fn=F_v2.adjust_sharpness,
        kwargs=[dict(sharpness_factor=0.5)],
    ),
    BenchmarkConfig("autocontrast", v1_fn=F_v1.autocontrast, v2_fn=F_v2.autocontrast),
    BenchmarkConfig("equalize", v1_fn=F_v1.equalize, v2_fn=F_v2.equalize),
    BenchmarkConfig(
        "convert_color_space",
        v1_fn=F_v1.rgb_to_grayscale,
        v2_fn=partial(F_v2.convert_color_space, color_space=ColorSpace.GRAY, old_color_space=ColorSpace.RGB),
    ),
    BenchmarkConfig(
        "convert_dtype",
        v1_fn=F_v1.convert_image_dtype,
        v2_fn=F_v2.convert_dtype,
        dtypes=[torch.uint8],
        devices=CPU_AND_MAYBE_CUDA,
        kwargs=[dict(dtype=torch.float32)],
    ),
    BenchmarkConfig("gaussian_blur", v1_fn=F_v1.gaussian_blur, v2_fn=F_v2.gaussian_blur, kwargs=[dict(kernel_size=5)]),
    BenchmarkConfig("invert", v1_fn=F_v1.invert, v2_fn=F_v2.invert),
    BenchmarkConfig("solarize", v1_fn=F_v1.solarize, v2_fn=F_v2.solarize, kwargs=[dict(threshold=0.5)]),
    BenchmarkConfig(
        "adjust_brightness",
        v1_fn=F_v1.adjust_brightness,
        v2_fn=F_v2.adjust_brightness,
        kwargs=[dict(brightness_factor=0.5)],
    ),
    BenchmarkConfig(
        "adjust_gamma", v1_fn=F_v1.adjust_gamma, v2_fn=F_v2.adjust_gamma, kwargs=[dict(gamma=0.9, gain=1.5)]
    ),
    BenchmarkConfig(
        "normalize",
        v1_fn=F_v1.normalize,
        v2_fn=F_v2.normalize,
        dtypes=[torch.float32],
        devices=CPU_AND_MAYBE_CUDA,
        kwargs=[dict(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))],
    ),
    BenchmarkConfig("posterize", v1_fn=F_v1.posterize, v2_fn=F_v2.posterize, kwargs=[dict(bits=3)]),
    BenchmarkConfig(
        "center_crop", v1_fn=F_v1.center_crop, v2_fn=F_v2.center_crop, kwargs=[dict(output_size=[224, 224])]
    ),
    BenchmarkConfig(
        "perspective",
        v1_fn=F_v1.perspective,
        v2_fn=F_v2.perspective,
        kwargs=[dict(startpoints=[[0, 0], [33, 0], [33, 25], [0, 25]], endpoints=[[3, 2], [32, 3], [30, 24], [2, 25]])],
    ),
    BenchmarkConfig(
        "resize",
        v1_fn=F_v1.resize,
        v2_fn=F_v2.resize,
        kwargs=product_dict(size=[[256, 256]], interpolation=[InterpolationMode.NEAREST, InterpolationMode.BILINEAR]),
    ),
    BenchmarkConfig(
        "resized_crop",
        v1_fn=F_v1.resized_crop,
        v2_fn=F_v2.resized_crop,
        kwargs=product_dict(
            top=[128],
            left=[128],
            height=[256],
            width=[256],
            size=[[512, 512]],
            interpolation=[InterpolationMode.NEAREST, InterpolationMode.BILINEAR],
        ),
    ),
    BenchmarkConfig(
        "elastic", v1_fn=F_v1.elastic_transform, v2_fn=F_v2.elastic, kwargs=[dict(displacement=ELASTIC_DISPLACEMENT)]
    ),
    BenchmarkConfig(
        "affine",
        v1_fn=F_v1.affine,
        v2_fn=F_v2.affine,
        kwargs=product_dict(
            angle=[30],
            translate=[[0, 0]],
            scale=[1.0],
            shear=[[0, 0]],
            interpolation=[InterpolationMode.NEAREST, InterpolationMode.BILINEAR],
        ),
    ),
    BenchmarkConfig("crop", v1_fn=F_v1.crop, v2_fn=F_v2.crop, kwargs=[dict(top=144, left=144, height=224, width=224)]),
    BenchmarkConfig("five_crop", v1_fn=F_v1.five_crop, v2_fn=F_v2.five_crop, kwargs=[dict(size=224)]),
    BenchmarkConfig("ten_crop", v1_fn=F_v1.ten_crop, v2_fn=F_v2.ten_crop, kwargs=[dict(size=224)]),
    BenchmarkConfig(
        "rotate",
        v1_fn=F_v1.rotate,
        v2_fn=F_v2.rotate,
        kwargs=product_dict(angle=[30], interpolation=[InterpolationMode.NEAREST, InterpolationMode.BILINEAR]),
    ),
    BenchmarkConfig("horizontal_flip", v1_fn=F_v1.hflip, v2_fn=F_v2.horizontal_flip),
    BenchmarkConfig("vertical_flip", v1_fn=F_v1.vflip, v2_fn=F_v2.vertical_flip),
    BenchmarkConfig("erase", v1_fn=F_v1.erase, v2_fn=F_v2.erase, kwargs=[dict(i=144, j=144, h=224, w=224, v=0)]),
    BenchmarkConfig("pad", v1_fn=F_v1.pad, v2_fn=F_v2.pad, kwargs=[dict(padding=2)]),
]
[-- convert_color_space (RGB -> GRAY) @ torchvision==0.15.0a0+b1f6c9e ---]
| v1 | v2
1 threads: ---------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 194 (+- 3) | 155 (+- 2)
(3, 400, 400) / uint8 / cuda | 43 (+- 2) | 26 (+- 0)
(3, 400, 400) / PIL | 67 (+- 2) | 66 (+- 0)
(3, 400, 400) / float32 / cpu | 95 (+- 0) | 51 (+- 1)
(3, 400, 400) / float32 / cuda | 32 (+- 1) | 18 (+- 0)
(16, 3, 400, 400) / uint8 / cpu | 3393 (+- 20) | 2054 (+- 15)
(16, 3, 400, 400) / uint8 / cuda | 479 (+- 0) | 316 (+- 0)
(16, 3, 400, 400) / float32 / cpu | 2535 (+- 40) | 1135 (+- 8)
(16, 3, 400, 400) / float32 / cuda | 481 (+- 0) | 318 (+- 0)
6 threads: ---------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 100 (+- 0) | 69 (+- 0)
(3, 400, 400) / float32 / cpu | 58 (+- 0) | 31 (+- 0)
(16, 3, 400, 400) / uint8 / cpu | 535 (+- 8) | 431 (+- 1)
(16, 3, 400, 400) / float32 / cpu | 327 (+- 2) | 139 (+- 1)
Times are in microseconds (us).
Aggregated performance change of v2 vs. v1: -34.3% (improvement)
[--- convert_dtype (float32 -> uint8) @ torchvision==0.15.0a0+b1f6c9e ---]
| v1 | v2
1 threads: ---------------------------------------------------------------
(3, 400, 400) / float32 / cpu | 238 (+- 0) | 231 (+- 1)
(3, 400, 400) / float32 / cuda | 26 (+- 0) | 26 (+- 0)
(16, 3, 400, 400) / float32 / cpu | 5035 (+- 16) | 4990 (+- 23)
(16, 3, 400, 400) / float32 / cuda | 395 (+- 0) | 395 (+- 0)
6 threads: ---------------------------------------------------------------
(3, 400, 400) / float32 / cpu | 61 (+- 0) | 57 (+- 0)
(16, 3, 400, 400) / float32 / cpu | 817 (+- 5) | 767 (+- 6)
Times are in microseconds (us).
Aggregated performance change of v2 vs. v1: -2.7% (improvement)
[-- convert_dtype (uint8 -> float32) @ torchvision==0.15.0a0+b1f6c9e --]
| v1 | v2
1 threads: -------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 92 (+- 1) | 60 (+- 2)
(3, 400, 400) / uint8 / cuda | 25 (+- 0) | 22 (+- 0)
(16, 3, 400, 400) / uint8 / cpu | 3575 (+- 62) | 1460 (+- 14)
(16, 3, 400, 400) / uint8 / cuda | 404 (+- 0) | 404 (+- 0)
6 threads: -------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 35 (+- 0) | 23 (+- 0)
(16, 3, 400, 400) / uint8 / cpu | 596 (+- 6) | 171 (+- 1)
Times are in microseconds (us).
Aggregated performance change of v2 vs. v1: -33.7% (improvement)
[---------------- crop @ torchvision==0.15.0a0+b1f6c9e ----------------]
| v1 | v2
1 threads: -------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 6 (+- 0) | 3 (+- 0)
(3, 400, 400) / uint8 / cuda | 5 (+- 0) | 3 (+- 0)
(3, 400, 400) / PIL | 10 (+- 0) | 8 (+- 0)
(3, 400, 400) / float32 / cpu | 5 (+- 0) | 3 (+- 0)
(3, 400, 400) / float32 / cuda | 5 (+- 0) | 3 (+- 0)
(16, 3, 400, 400) / uint8 / cpu | 5 (+- 1) | 3 (+- 0)
(16, 3, 400, 400) / uint8 / cuda | 6 (+- 0) | 3 (+- 0)
(16, 3, 400, 400) / float32 / cpu | 5 (+- 1) | 3 (+- 0)
(16, 3, 400, 400) / float32 / cuda | 5 (+- 0) | 4 (+- 0)
6 threads: -------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 5 (+- 0) | 4 (+- 0)
(3, 400, 400) / float32 / cpu | 5 (+- 0) | 4 (+- 0)
(16, 3, 400, 400) / uint8 / cpu | 6 (+- 0) | 3 (+- 0)
(16, 3, 400, 400) / float32 / cpu | 5 (+- 0) | 3 (+- 0)
Times are in microseconds (us).
Aggregated performance change of v2 vs. v1: -32.6% (improvement)
[---------------- elastic @ torchvision==0.15.0a0+b1f6c9e -----------------]
| v1 | v2
1 threads: -----------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 2499 (+- 80) | 2591 (+- 6)
(3, 400, 400) / uint8 / cuda | 566 (+- 1) | 239 (+- 0)
(3, 400, 400) / PIL | 4240 (+- 99) | 4312 (+- 5)
(3, 400, 400) / float32 / cpu | 2382 (+- 4) | 2315 (+- 2)
(3, 400, 400) / float32 / cuda | 545 (+- 1) | 222 (+- 2)
(16, 3, 400, 400) / uint8 / cpu | 52825 (+-243) | 53608 (+-345)
(16, 3, 400, 400) / uint8 / cuda | 1922 (+- 2) | 1593 (+- 2)
(16, 3, 400, 400) / float32 / cpu | 45891 (+-421) | 47067 (+-143)
(16, 3, 400, 400) / float32 / cuda | 1368 (+- 7) | 1031 (+- 3)
6 threads: -----------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 3356 (+- 51) | 2908 (+- 55)
(3, 400, 400) / float32 / cpu | 2167 (+- 81) | 2135 (+- 2)
(16, 3, 400, 400) / uint8 / cpu | 11317 (+-106) | 10785 (+- 84)
(16, 3, 400, 400) / float32 / cpu | 7678 (+- 58) | 7355 (+- 37)
Times are in microseconds (us).
Aggregated performance change of v2 vs. v1: -6.7% (improvement)
[----------------- equalize @ torchvision==0.15.0a0+b1f6c9e -----------------]
| v1 | v2
1 threads: -------------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 1693 (+- 4) | 1515 (+-630)
(3, 400, 400) / uint8 / cuda | 615 (+- 2) | 287 (+- 0)
(3, 400, 400) / PIL | 389 (+- 0) | 387 (+- 0)
(16, 3, 400, 400) / uint8 / cpu | 27020 (+- 25) | 47607 (+- 93)
(16, 3, 400, 400) / uint8 / cuda | 9647 (+-192) | 2751 (+- 9)
(3, 400, 400) / float32 / cpu | | 1806 (+- 2)
(3, 400, 400) / float32 / cuda | | 348 (+- 0)
(16, 3, 400, 400) / float32 / cpu | | 56225 (+-11182)
(16, 3, 400, 400) / float32 / cuda | | 3557 (+- 7)
6 threads: -------------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 791 (+- 3) | 982 (+- 40)
(16, 3, 400, 400) / uint8 / cpu | 18263 (+-246) | 9063 (+- 63)
(3, 400, 400) / float32 / cpu | | 1079 (+- 6)
(16, 3, 400, 400) / float32 / cpu | | 17992 (+-548)
Times are in microseconds (us).
Aggregated performance change of v2 vs. v1: -47.0% (improvement)
[------------- five_crop @ torchvision==0.15.0a0+b1f6c9e --------------]
| v1 | v2
1 threads: -------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 38 (+- 0) | 15 (+- 0)
(3, 400, 400) / uint8 / cuda | 38 (+- 0) | 18 (+- 0)
(3, 400, 400) / PIL | 71 (+- 0) | 55 (+- 1)
(3, 400, 400) / float32 / cpu | 37 (+- 0) | 18 (+- 0)
(3, 400, 400) / float32 / cuda | 32 (+- 4) | 16 (+- 0)
(16, 3, 400, 400) / uint8 / cpu | 38 (+- 0) | 16 (+- 0)
(16, 3, 400, 400) / uint8 / cuda | 38 (+- 0) | 16 (+- 0)
(16, 3, 400, 400) / float32 / cpu | 38 (+- 0) | 16 (+- 1)
(16, 3, 400, 400) / float32 / cuda | 30 (+- 0) | 18 (+- 1)
6 threads: -------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 37 (+- 4) | 18 (+- 0)
(3, 400, 400) / float32 / cpu | 29 (+- 4) | 15 (+- 0)
(16, 3, 400, 400) / uint8 / cpu | 38 (+- 0) | 16 (+- 0)
(16, 3, 400, 400) / float32 / cpu | 38 (+- 0) | 18 (+- 1)
Times are in microseconds (us).
Aggregated performance change of v2 vs. v1: -55.1% (improvement)
[-------------- gaussian_blur @ torchvision==0.15.0a0+b1f6c9e --------------]
| v1 | v2
1 threads: ------------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 1937 (+- 3) | 1908 (+- 3)
(3, 400, 400) / uint8 / cuda | 368 (+- 1) | 313 (+- 0)
(3, 400, 400) / PIL | 4302 (+- 5) | 3786 (+-210)
(3, 400, 400) / float32 / cpu | 2125 (+- 3) | 1640 (+- 2)
(3, 400, 400) / float32 / cuda | 332 (+- 1) | 278 (+- 1)
(16, 3, 400, 400) / uint8 / cpu | 57534 (+-3550) | 50031 (+- 36)
(16, 3, 400, 400) / uint8 / cuda | 4718 (+- 3) | 4664 (+- 1)
(16, 3, 400, 400) / float32 / cpu | 49613 (+-3496) | 42184 (+- 59)
(16, 3, 400, 400) / float32 / cuda | 4155 (+- 3) | 4103 (+- 4)
6 threads: ------------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 1068 (+- 4) | 1000 (+- 3)
(3, 400, 400) / float32 / cpu | 782 (+- 70) | 751 (+- 1)
(16, 3, 400, 400) / uint8 / cpu | 31630 (+-392) | 29284 (+- 76)
(16, 3, 400, 400) / float32 / cpu | 24992 (+- 96) | 25460 (+- 70)
Times are in microseconds (us).
Aggregated performance change of v2 vs. v1: -2.3% (improvement)
[---------------- invert @ torchvision==0.15.0a0+b1f6c9e ----------------]
| v1 | v2
1 threads: ---------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 147 (+- 0) | 11 (+- 0)
(3, 400, 400) / uint8 / cuda | 10 (+- 0) | 4 (+- 0)
(3, 400, 400) / PIL | 167 (+- 0) | 165 (+- 0)
(3, 400, 400) / float32 / cpu | 42 (+- 0) | 39 (+- 1)
(3, 400, 400) / float32 / cuda | 15 (+- 0) | 15 (+- 0)
(16, 3, 400, 400) / uint8 / cpu | 2201 (+- 1) | 130 (+- 1)
(16, 3, 400, 400) / uint8 / cuda | 60 (+- 0) | 60 (+- 0)
(16, 3, 400, 400) / float32 / cpu | 2045 (+- 4) | 2030 (+- 14)
(16, 3, 400, 400) / float32 / cuda | 239 (+- 0) | 239 (+- 0)
6 threads: ---------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 36 (+- 0) | 6 (+- 0)
(3, 400, 400) / float32 / cpu | 19 (+- 0) | 16 (+- 0)
(16, 3, 400, 400) / uint8 / cpu | 391 (+- 0) | 31 (+- 0)
(16, 3, 400, 400) / float32 / cpu | 149 (+- 1) | 158 (+- 2)
Times are in microseconds (us).
Aggregated performance change of v2 vs. v1: -0.9% (improvement)
[-------------- normalize @ torchvision==0.15.0a0+b1f6c9e ---------------]
| v1 | v2
1 threads: ---------------------------------------------------------------
(3, 400, 400) / float32 / cpu | 128 (+- 1) | 93 (+- 2)
(3, 400, 400) / float32 / cuda | 91 (+- 0) | 54 (+- 0)
(16, 3, 400, 400) / float32 / cpu | 3528 (+- 26) | 2507 (+- 9)
(16, 3, 400, 400) / float32 / cuda | 764 (+- 2) | 501 (+- 1)
6 threads: ---------------------------------------------------------------
(3, 400, 400) / float32 / cpu | 54 (+- 0) | 36 (+- 0)
(16, 3, 400, 400) / float32 / cpu | 381 (+- 3) | 289 (+- 3)
Times are in microseconds (us).
Aggregated performance change of v2 vs. v1: -30.1% (improvement)
[-------------- perspective @ torchvision==0.15.0a0+b1f6c9e ---------------]
| v1 | v2
1 threads: -----------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 4166 (+- 24) | 4112 (+- 2)
(3, 400, 400) / uint8 / cuda | 593 (+- 2) | 585 (+- 2)
(3, 400, 400) / PIL | 1173 (+- 1) | 1174 (+- 1)
(3, 400, 400) / float32 / cpu | 3997 (+- 75) | 3914 (+- 2)
(3, 400, 400) / float32 / cuda | 548 (+- 2) | 530 (+- 1)
(16, 3, 400, 400) / uint8 / cpu | 39752 (+-119) | 38303 (+-188)
(16, 3, 400, 400) / uint8 / cuda | 1665 (+- 1) | 1655 (+- 1)
(16, 3, 400, 400) / float32 / cpu | 32773 (+-198) | 32349 (+-197)
(16, 3, 400, 400) / float32 / cuda | 1080 (+- 8) | 1075 (+- 2)
6 threads: -----------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 3492 (+- 7) | 3039 (+- 17)
(3, 400, 400) / float32 / cpu | 2915 (+- 20) | 3082 (+- 13)
(16, 3, 400, 400) / uint8 / cpu | 10134 (+- 96) | 8375 (+- 45)
(16, 3, 400, 400) / float32 / cpu | 6187 (+- 22) | 6453 (+- 24)
Times are in microseconds (us).
Aggregated performance change of v2 vs. v1: -0.8% (improvement)
[-------------- posterize @ torchvision==0.15.0a0+b1f6c9e ---------------]
| v1 | v2
1 threads: ---------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 105 (+- 0) | 101 (+- 0)
(3, 400, 400) / uint8 / cuda | 10 (+- 0) | 6 (+- 0)
(3, 400, 400) / PIL | 173 (+- 0) | 171 (+- 0)
(16, 3, 400, 400) / uint8 / cpu | 1514 (+- 3) | 1522 (+- 3)
(16, 3, 400, 400) / uint8 / cuda | 60 (+- 0) | 60 (+- 0)
(3, 400, 400) / float32 / cpu | | 118 (+- 0)
(3, 400, 400) / float32 / cuda | | 40 (+- 0)
(16, 3, 400, 400) / float32 / cpu | | 4577 (+-126)
(16, 3, 400, 400) / float32 / cuda | | 955 (+- 0)
6 threads: ---------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 28 (+- 0) | 25 (+- 0)
(16, 3, 400, 400) / uint8 / cpu | 276 (+- 0) | 272 (+- 0)
(3, 400, 400) / float32 / cpu | | 43 (+- 0)
(16, 3, 400, 400) / float32 / cpu | | 388 (+- 5)
Times are in microseconds (us).
Aggregated performance change of v2 vs. v1: -2.2% (improvement)
import contextlib
import dataclasses
import pathlib
import pickle
import sys
import unittest.mock
from collections import defaultdict
from statistics import median, mean
from typing import *
import tabulate
import torch
from torch.utils import benchmark as torch_benchmark


def main(paths, aggregate=False):
    # Load the pickled measurements from every file and report them together.
    measurements = []
    for path in paths:
        with open(path, "rb") as file:
            measurements.extend(pickle.load(file))

    report(measurements, aggregate=aggregate)


@dataclasses.dataclass(unsafe_hash=True)
class Setup:
    shape: Tuple[int, ...]
    dtype: torch.dtype
    device: str
    num_threads: int
    kwargs_idx: int


def report(measurements, *, file_path=None, aggregate=False):
    @contextlib.contextmanager
    def log():
        # Print to the console unless a file path is given; in that case redirect stdout to the file.
        if file_path is None:
            yield
            return

        with open(file_path, "w") as file, contextlib.redirect_stdout(file):
            yield

    def format_change(change):
        return f"{change:+.1%} ({'improvement' if change >= 0 else 'slowdown'})"

    with log():
        with display_spread():
            comparison = torch_benchmark.Compare(measurements)
            comparison.print()

        grouped_changes = compute_grouped_changes(measurements)

        if aggregate:
            print(f"Aggregated performance change of v2 vs. v1: {format_change(mean(grouped_changes.values()))}")
        else:
            table = tabulate.tabulate(
                [
                    (setup, format_change(change))
                    for setup, change in sorted(grouped_changes.items(), key=lambda key_value: key_value[1])
                ],
                headers=("setup", "performance change"),
                tablefmt="github",
            )
            print(f"Performance changes v2 vs. v1:\n\n{table}")


@contextlib.contextmanager
def display_spread():
    # Temporarily patch the row rendering of Compare so that each cell shows "median (+- spread)".
    with unittest.mock.patch(
        "torch.utils.benchmark.utils.compare._Row.as_column_strings", new=patched_as_column_strings
    ):
        yield


def patched_as_column_strings(self):
    concrete_results = [r for r in self._results if r is not None]
    env = f"({concrete_results[0].env})" if self._render_env else ""
    env = env.ljust(self._env_str_len + 4)
    output = [" " + env + concrete_results[0].as_row_name]
    for m, col in zip(self._results, self._columns or ()):
        if m is None:
            output.append(col.num_to_str(None, 1, None))
        else:
            # The spread is the sample standard deviation of the raw times; a single run has none.
            if len(m.times) == 1:
                spread = 0
            else:
                spread = float(torch.tensor(m.times, dtype=torch.float64).std(unbiased=len(m.times) > 1))
                if col._trim_significant_figures:
                    spread = torch_benchmark.utils.common.trim_sigfig(spread, m.significant_figures)
            output.append(f"{m.median / self._time_scale:>3.0f} (+-{spread / self._time_scale:>3.0f})")
    return output


def compute_grouped_changes(measurements):
    measurements_v1 = {}
    measurements_v2 = {}
    for m in measurements:
        dct = measurements_v1 if m.task_spec.description == "v1" else measurements_v2
        key = (m.task_spec.label, m.setup)
        assert key not in dct
        dct[key] = m

    # Relative change per measurement: positive means v2 is faster than v1.
    changes = {
        key: 1 - measurements_v2[key].median / measurement_v1.median for key, measurement_v1 in measurements_v1.items()
    }

    grouped_changes = defaultdict(list)
    for (_, setup), change in changes.items():
        group = (setup.device, setup.dtype)
        grouped_changes[group].append(change)

    # Summarize each (device, dtype) group by the median of its changes.
    return {key: median(values) for key, values in grouped_changes.items()}


if __name__ == "__main__":
    paths = [pathlib.Path(path).resolve() for path in sys.argv[1:]]
    if not paths:
        paths = list((pathlib.Path.cwd() / "transforms_v1_vs_v2").glob("*.pkl"))
    main(paths)
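The grouping step in compute_grouped_changes can be tried without torch or pickled measurements. The following is a minimal sketch under stated assumptions: the operation names (op_a, op_b) and all timings are made up for illustration, and plain tuples stand in for the Setup objects. It uses the script's convention, change = 1 - v2 / v1, so a positive value means v2 is faster. The result tables in this gist appear to use the opposite sign (a negative percentage is labelled an improvement), so they were presumably produced by a different revision of the script.

```python
from collections import defaultdict
from statistics import median

# Hypothetical median runtimes in microseconds, keyed by (label, device, dtype).
# These numbers are invented for illustration and are not taken from the benchmarks.
v1 = {
    ("op_a", "cpu", "uint8"): 400.0,
    ("op_b", "cpu", "uint8"): 200.0,
    ("op_a", "cuda", "uint8"): 100.0,
}
v2 = {
    ("op_a", "cpu", "uint8"): 300.0,
    ("op_b", "cpu", "uint8"): 100.0,
    ("op_a", "cuda", "uint8"): 110.0,
}


def grouped_changes(v1, v2):
    # Relative change per measurement: positive means v2 is faster than v1.
    changes = {key: 1 - v2[key] / t1 for key, t1 in v1.items()}

    # Collect the changes of all operations that share a (device, dtype) pair.
    groups = defaultdict(list)
    for (_, device, dtype), change in changes.items():
        groups[(device, dtype)].append(change)

    # One summary value per group: the median of its changes.
    return {group: median(values) for group, values in groups.items()}


result = grouped_changes(v1, v2)
for group, change in sorted(result.items(), key=lambda kv: kv[1]):
    print(group, f"{change:+.1%}")
```

Using the median within each group keeps a single noisy measurement from dominating that group's summary.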
[------------ resize BILINEAR @ torchvision==0.15.0a0+b1f6c9e -------------]
| v1 | v2
1 threads: -----------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 1315 (+- 6) | 1440 (+- 3)
(3, 400, 400) / uint8 / cuda | 33 (+- 0) | 32 (+- 0)
(3, 400, 400) / PIL | 888 (+- 1) | 893 (+- 1)
(3, 400, 400) / float32 / cpu | 1312 (+- 54) | 1299 (+- 68)
(3, 400, 400) / float32 / cuda | 17 (+- 0) | 17 (+- 0)
(16, 3, 400, 400) / uint8 / cpu | 23415 (+-288) | 23270 (+-144)
(16, 3, 400, 400) / uint8 / cuda | 521 (+- 0) | 521 (+- 0)
(16, 3, 400, 400) / float32 / cpu | 19306 (+-191) | 19280 (+-232)
(16, 3, 400, 400) / float32 / cuda | 191 (+- 0) | 191 (+- 0)
6 threads: -----------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 199 (+- 0) | 200 (+- 0)
(3, 400, 400) / float32 / cpu | 160 (+- 1) | 158 (+- 2)
(16, 3, 400, 400) / uint8 / cpu | 2573 (+- 11) | 2517 (+- 10)
(16, 3, 400, 400) / float32 / cpu | 2079 (+- 11) | 2100 (+- 9)
Times are in microseconds (us).
Aggregated performance change of v2 vs. v1: +0.2% (slowdown)
[------------- resize NEAREST @ torchvision==0.15.0a0+b1f6c9e ------------]
| v1 | v2
1 threads: ----------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 256 (+- 4) | 252 (+- 2)
(3, 400, 400) / uint8 / cuda | 32 (+- 3) | 27 (+- 0)
(3, 400, 400) / PIL | 46 (+- 2) | 39 (+- 1)
(3, 400, 400) / float32 / cpu | 121 (+- 1) | 119 (+- 0)
(3, 400, 400) / float32 / cuda | 15 (+- 2) | 12 (+- 0)
(16, 3, 400, 400) / uint8 / cpu | 6393 (+-1301) | 5862 (+- 64)
(16, 3, 400, 400) / uint8 / cuda | 462 (+- 0) | 462 (+- 0)
(16, 3, 400, 400) / float32 / cpu | 2569 (+- 19) | 2567 (+- 31)
(16, 3, 400, 400) / float32 / cuda | 132 (+- 0) | 132 (+- 0)
6 threads: ----------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 135 (+- 2) | 126 (+- 2)
(3, 400, 400) / float32 / cpu | 95 (+- 2) | 88 (+- 1)
(16, 3, 400, 400) / uint8 / cpu | 1597 (+- 8) | 1476 (+- 5)
(16, 3, 400, 400) / float32 / cpu | 1112 (+- 4) | 1111 (+- 1)
Times are in microseconds (us).
Aggregated performance change of v2 vs. v1: -2.1% (improvement)
[---------- resized_crop BILINEAR @ torchvision==0.15.0a0+b1f6c9e ----------]
| v1 | v2
1 threads: ------------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 3884 (+- 6) | 4453 (+- 6)
(3, 400, 400) / uint8 / cuda | 86 (+- 0) | 84 (+- 0)
(3, 400, 400) / PIL | 1499 (+- 1) | 1502 (+- 1)
(3, 400, 400) / float32 / cpu | 3508 (+- 8) | 3503 (+-149)
(3, 400, 400) / float32 / cuda | 36 (+- 0) | 36 (+- 0)
(16, 3, 400, 400) / uint8 / cpu | 69721 (+-713) | 68454 (+-2048)
(16, 3, 400, 400) / uint8 / cuda | 1078 (+- 2) | 1077 (+- 0)
(16, 3, 400, 400) / float32 / cpu | 59214 (+-881) | 58483 (+-3347)
(16, 3, 400, 400) / float32 / cuda | 350 (+- 0) | 350 (+- 0)
6 threads: ------------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 651 (+- 3) | 628 (+- 0)
(3, 400, 400) / float32 / cpu | 554 (+- 4) | 547 (+- 3)
(16, 3, 400, 400) / uint8 / cpu | 13085 (+-103) | 9515 (+- 36)
(16, 3, 400, 400) / float32 / cpu | 8154 (+- 32) | 8161 (+- 29)
Times are in microseconds (us).
Aggregated performance change of v2 vs. v1: +0.0% (slowdown)
[---------- resized_crop NEAREST @ torchvision==0.15.0a0+b1f6c9e ----------]
| v1 | v2
1 threads: -----------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 790 (+- 2) | 1984 (+-604)
(3, 400, 400) / uint8 / cuda | 69 (+- 0) | 67 (+- 0)
(3, 400, 400) / PIL | 148 (+- 0) | 142 (+- 0)
(3, 400, 400) / float32 / cpu | 393 (+- 0) | 387 (+- 1)
(3, 400, 400) / float32 / cuda | 30 (+- 0) | 25 (+- 0)
(16, 3, 400, 400) / uint8 / cpu | 16319 (+-167) | 14447 (+- 20)
(16, 3, 400, 400) / uint8 / cuda | 984 (+- 0) | 984 (+- 0)
(16, 3, 400, 400) / float32 / cpu | 6085 (+- 18) | 6125 (+- 55)
(16, 3, 400, 400) / float32 / cuda | 356 (+- 0) | 356 (+- 0)
6 threads: -----------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 394 (+- 3) | 376 (+- 0)
(3, 400, 400) / float32 / cpu | 304 (+- 3) | 293 (+- 1)
(16, 3, 400, 400) / uint8 / cpu | 9265 (+-611) | 5641 (+- 16)
(16, 3, 400, 400) / float32 / cpu | 4328 (+- 7) | 4329 (+- 2)
Times are in microseconds (us).
Aggregated performance change of v2 vs. v1: -1.3% (improvement)
[------------ rotate BILINEAR @ torchvision==0.15.0a0+b1f6c9e -------------]
| v1 | v2
1 threads: -----------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 3147 (+- 44) | 3173 (+- 2)
(3, 400, 400) / uint8 / cuda | 375 (+- 2) | 377 (+- 1)
(3, 400, 400) / PIL | 2484 (+- 18) | 2324 (+- 2)
(3, 400, 400) / float32 / cpu | 2977 (+- 5) | 2909 (+- 3)
(3, 400, 400) / float32 / cuda | 330 (+- 1) | 326 (+- 1)
(16, 3, 400, 400) / uint8 / cpu | 37085 (+-101) | 35838 (+-110)
(16, 3, 400, 400) / uint8 / cuda | 1622 (+- 3) | 1617 (+- 2)
(16, 3, 400, 400) / float32 / cpu | 29632 (+- 29) | 29679 (+- 40)
(16, 3, 400, 400) / float32 / cuda | 1049 (+- 10) | 1047 (+- 5)
6 threads: -----------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 3298 (+- 3) | 2610 (+- 19)
(3, 400, 400) / float32 / cpu | 2857 (+- 10) | 2853 (+- 4)
(16, 3, 400, 400) / uint8 / cpu | 9776 (+- 77) | 8128 (+-118)
(16, 3, 400, 400) / float32 / cpu | 5987 (+- 12) | 5986 (+- 16)
Times are in microseconds (us).
Aggregated performance change of v2 vs. v1: -0.8% (improvement)
[------------- rotate NEAREST @ torchvision==0.15.0a0+b1f6c9e -------------]
| v1 | v2
1 threads: -----------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 2045 (+- 41) | 1941 (+- 16)
(3, 400, 400) / uint8 / cuda | 351 (+- 1) | 353 (+- 1)
(3, 400, 400) / PIL | 238 (+- 1) | 237 (+- 8)
(3, 400, 400) / float32 / cpu | 1672 (+- 1) | 1675 (+- 19)
(3, 400, 400) / float32 / cuda | 311 (+- 1) | 306 (+- 1)
(16, 3, 400, 400) / uint8 / cpu | 16963 (+- 55) | 16315 (+- 63)
(16, 3, 400, 400) / uint8 / cuda | 1300 (+- 0) | 1293 (+- 2)
(16, 3, 400, 400) / float32 / cpu | 9209 (+- 28) | 9077 (+- 35)
(16, 3, 400, 400) / float32 / cuda | 728 (+- 6) | 725 (+- 3)
6 threads: -----------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 1892 (+- 79) | 1354 (+- 8)
(3, 400, 400) / float32 / cpu | 1510 (+- 82) | 1666 (+- 8)
(16, 3, 400, 400) / uint8 / cpu | 6114 (+- 52) | 4867 (+- 99)
(16, 3, 400, 400) / float32 / cpu | 2167 (+- 11) | 2211 (+- 7)
Times are in microseconds (us).
Aggregated performance change of v2 vs. v1: -0.3% (improvement)
[----------------- solarize @ torchvision==0.15.0a0+b1f6c9e ----------------]
| v1 | v2
1 threads: ------------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 1481 (+- 1) | 622 (+-359)
(3, 400, 400) / uint8 / cuda | 22 (+- 2) | 15 (+- 0)
(3, 400, 400) / PIL | 170 (+- 0) | 198 (+- 0)
(3, 400, 400) / float32 / cpu | 1640 (+- 1) | 2158 (+- 2)
(3, 400, 400) / float32 / cuda | 50 (+- 0) | 51 (+- 0)
(16, 3, 400, 400) / uint8 / cpu | 12707 (+- 22) | 11026 (+- 84)
(16, 3, 400, 400) / uint8 / cuda | 309 (+- 0) | 310 (+- 0)
(16, 3, 400, 400) / float32 / cpu | 34022 (+-4072) | 27862 (+- 22)
(16, 3, 400, 400) / float32 / cuda | 771 (+- 0) | 772 (+- 1)
6 threads: ------------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 283 (+- 0) | 252 (+- 1)
(3, 400, 400) / float32 / cpu | 365 (+- 29) | 358 (+- 21)
(16, 3, 400, 400) / uint8 / cpu | 4128 (+-866) | 3789 (+- 3)
(16, 3, 400, 400) / float32 / cpu | 7230 (+-662) | 7181 (+- 20)
Times are in microseconds (us).
Aggregated performance change of v2 vs. v1: +0.2% (slowdown)
[---------------- ten_crop @ torchvision==0.15.0a0+b1f6c9e ---------------]
| v1 | v2
1 threads: ----------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 615 (+-351) | 255 (+- 0)
(3, 400, 400) / uint8 / cuda | 75 (+- 9) | 41 (+- 1)
(3, 400, 400) / PIL | 199 (+- 10) | 170 (+- 0)
(3, 400, 400) / float32 / cpu | 289 (+- 6) | 259 (+- 0)
(3, 400, 400) / float32 / cuda | 90 (+- 0) | 40 (+- 1)
(16, 3, 400, 400) / uint8 / cpu | 3580 (+-5070) | 3561 (+- 1)
(16, 3, 400, 400) / uint8 / cuda | 105 (+- 1) | 106 (+- 0)
(16, 3, 400, 400) / float32 / cpu | 4833 (+- 38) | 4832 (+- 20)
(16, 3, 400, 400) / float32 / cuda | 241 (+- 0) | 241 (+- 0)
6 threads: ----------------------------------------------------------------
(3, 400, 400) / uint8 / cpu | 233 (+- 0) | 201 (+- 0)
(3, 400, 400) / float32 / cpu | 128 (+- 0) | 79 (+- 3)
(16, 3, 400, 400) / uint8 / cpu | 2663 (+- 12) | 2640 (+-852)
(16, 3, 400, 400) / float32 / cpu | 716 (+- 7) | 663 (+- 5)
Times are in microseconds (us).
Aggregated performance change of v2 vs. v1: -14.6% (improvement)