Skip to content

Instantly share code, notes, and snippets.

@vfdev-5
Last active November 23, 2021 15:32
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save vfdev-5/8c26a109d7718035162a6d5d138b5499 to your computer and use it in GitHub Desktop.
Save vfdev-5/8c26a109d7718035162a6d5d138b5499 to your computer and use it in GitHub Desktop.
from pathlib import Path
import PIL
from PIL import Image
import torch
import torch.nn as nn
import torch.utils.benchmark as benchmark
import torchvision
import torchvision.transforms as T
import torchvision.transforms.functional as F
import torchvision.io as io
# Side length of the generated test image and of the resized output.
isize = 500
osize = 224
# Per-channel normalization statistics (same value repeated for 3 channels).
mean = (0.1,) * 3
std = (0.2,) * 3
# Benchmark input image; the __main__ guard creates it when it is missing.
img_path = Path(f"test_{isize}.jpg")
def pil_loader(path):
    """Open the image file at *path* with PIL and return an RGB copy of it.

    Both the underlying binary file and the PIL image are closed before the
    converted image is handed back to the caller.
    """
    with open(path, 'rb') as stream, Image.open(stream) as picture:
        rgb_image = picture.convert('RGB')
    return rgb_image
def torch_loader(path):
    """Decode the image at *path* with ``torchvision.io.read_image`` in RGB mode.

    Args:
        path: image file path, either a ``str`` or a path-like object such as
            ``pathlib.Path``. It is converted with ``str()`` because
            ``read_image`` expects a string path, which keeps this loader
            interchangeable with ``pil_loader`` (that one accepts both).

    Returns:
        The tensor produced by ``read_image`` with ``ImageReadMode.RGB``.
        Per the torchvision docs this is a (C, H, W) uint8 tensor; confirm
        against the installed torchvision version.
    """
    return io.read_image(str(path), io.ImageReadMode.RGB)
def _build_transforms():
    """Return the benchmark cases as ``(reader_fn, label, transform)`` tuples.

    The first two cases decode with ``pil_loader`` and chain steps with
    ``T.Compose``. The last two decode with ``torch_loader`` and use an
    ``nn.Sequential`` pipeline; the final case is compiled with
    ``torch.jit.script``.
    """
    size = (osize, osize)

    def tensor_pipeline():
        # Build a fresh module per call so the eager and scripted cases
        # never share module instances.
        return nn.Sequential(
            T.ConvertImageDtype(torch.float), T.Resize(size), T.Normalize(mean, std)
        )

    resize_then_tensor = T.Compose([T.Resize(size), T.ToTensor(), T.Normalize(mean, std)])
    tensor_then_resize = T.Compose([T.ToTensor(), T.Resize(size), T.Normalize(mean, std)])
    return [
        (pil_loader, f"PIL->Resize ({isize}->{osize}) ->Tensor->Norm", resize_then_tensor),
        (pil_loader, f"PIL->Tensor->Resize ({isize}->{osize}) ->Norm", tensor_then_resize),
        (torch_loader, f"PTH->DType->Resize ({isize}->{osize}) ->Norm", tensor_pipeline()),
        (
            torch_loader,
            f"JIT: PTH->DType->Resize ({isize}->{osize}) ->Norm",
            torch.jit.script(tensor_pipeline()),
        ),
    ]


transforms = _build_transforms()
def run_bench(t, num_threads=(1, 6), min_run_time=3, image_path=None):
    """Time ``transform(read(path))`` for one benchmark case.

    One measurement is taken per entry of ``num_threads``. Each count is
    forwarded to ``benchmark.Timer`` as its ``num_threads`` argument, so this
    function no longer calls ``torch.set_num_threads`` and leaves the
    process-wide setting alone.

    Args:
        t: ``(reader_fn, label, transform)`` tuple, as stored in ``transforms``.
        num_threads: torch thread counts to measure (default ``(1, 6)``).
        min_run_time: minimum seconds forwarded to ``Timer.blocked_autorange``.
        image_path: image to read; defaults to the module-level ``img_path``.

    Returns:
        list of ``benchmark.Measurement`` objects, in ``num_threads`` order.
    """
    reader_fn, label, transform = t
    if image_path is None:
        image_path = img_path
    # Pass the path to the timed statement through ``globals`` rather than
    # formatting it into the source: a quote in the path would otherwise
    # make ``stmt`` invalid Python.
    path = Path(image_path).as_posix()
    results = []
    for n_threads in num_threads:
        timer = benchmark.Timer(
            stmt="transform( read(path) )",
            globals={"read": reader_fn, "transform": transform, "path": path},
            num_threads=n_threads,
            label="Benchmark reader+transformation",
            description="Time",
            sub_label=label,
        )
        results.append(timer.blocked_autorange(min_run_time=min_run_time))
    return results
def main():
    """Benchmark every case in ``transforms`` and print a single comparison table."""
    measurements = [m for case in transforms for m in run_bench(case)]
    benchmark.Compare(measurements).print()
def write_random_image(path, size=64, num_channels=3):
    """Save a ``size`` x ``size`` image of uniform random uint8 noise to *path*.

    Pixels are drawn with ``torch.randint`` over [0, 255]. The image is
    converted to RGB before saving, and PIL chooses the file format from
    *path*.

    Args:
        path: destination file (``str`` or ``pathlib.Path``).
        size: side length in pixels.
        num_channels: channels of the random array before RGB conversion.
            1 (grayscale) and 3 (RGB) are handled explicitly; other counts
            work only if ``PIL.Image.fromarray`` accepts the (H, W, C) array.

    Raises:
        FileNotFoundError: if *path* does not exist after saving.
    """
    tensor = torch.randint(0, 256, size=(num_channels, size, size), dtype=torch.uint8)
    if num_channels == 1:
        # PIL.Image.fromarray has no mode for an (H, W, 1) uint8 array; pass a
        # 2-D (H, W) array instead so it becomes an 8-bit grayscale image.
        data = tensor[0].cpu().numpy()
    else:
        # CHW -> HWC, the channel-last layout PIL expects.
        data = tensor.permute(1, 2, 0).contiguous().cpu().numpy()
    pil_img = Image.fromarray(data).convert("RGB")
    pil_img.save(path)
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if not Path(path).exists():
        raise FileNotFoundError(path)
if __name__ == "__main__":
    # Report the software stack first so each timing table can be traced to it.
    print(f"Torch config: {torch.__config__.show()}")
    print(f"Torch version: {torch.__version__}")
    print(f"Torchvision version: {torchvision.__version__}")
    print(f"PIL version: {PIL.__version__}")

    # Generate the input image only when it is missing; otherwise reuse it.
    need_image = not img_path.exists()
    if need_image:
        write_random_image(img_path, size=isize)

    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment