Skip to content

Instantly share code, notes, and snippets.

@andfoy
Last active July 8, 2020 16:23
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save andfoy/c9b75470c6557771d55d96ed40dbafd4 to your computer and use it in GitHub Desktop.
Save andfoy/c9b75470c6557771d55d96ed40dbafd4 to your computer and use it in GitHub Desktop.
import os
import unittest
import sys
import torch
import torchvision
from PIL import Image
from torchvision.io.image import read_png, decode_png, read_jpeg, decode_jpeg
import numpy as np
# "assets" directory located next to this test file; searched for .jpg images.
IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
# Sub-directory of the assets tree that the PNG tests search for .png images.
IMAGE_DIR = os.path.join(IMAGE_ROOT, "fakedata", "imagefolder")
def get_images(directory, img_ext):
    """Yield the path of every file under ``directory`` with extension ``img_ext``.

    The tree is walked recursively with ``os.walk``. The extension is compared
    exactly (case-sensitively) against the result of ``os.path.splitext``, so
    ``img_ext`` must include the leading dot, e.g. ``".png"``.

    Args:
        directory: Root directory to search.
        img_ext: File extension to match, including the leading dot.

    Yields:
        ``os.path.join(root, filename)`` for each matching file.

    Raises:
        AssertionError: If ``directory`` is not an existing directory. Because
            this is a generator, the check runs when iteration starts.
    """
    # Raise explicitly rather than using an ``assert`` statement: asserts are
    # stripped under ``python -O``, and a wrong path would then silently yield
    # nothing, letting every caller's loop pass without checking anything.
    if not os.path.isdir(directory):
        raise AssertionError("not a directory: {!r}".format(directory))
    for root, _, files in os.walk(directory):
        for fl in files:
            if os.path.splitext(fl)[1] == img_ext:
                yield os.path.join(root, fl)
class ImageTester(unittest.TestCase):
    """Check the torchvision JPEG/PNG readers and decoders against PIL.

    Each positive test loads an image with PIL (converted through NumPy) as
    the reference tensor and requires the tensor returned by the function
    under test to compare equal to it with ``Tensor.equal``. The decode tests
    additionally check the exception types raised for malformed inputs.
    """

    def test_read_jpeg(self):
        """``read_jpeg(path)`` equals the PIL reference for every ``.jpg`` asset."""
        for img_path in get_images(IMAGE_ROOT, ".jpg"):
            # subTest names the offending file on failure and keeps checking
            # the remaining images.
            with self.subTest(img_path=img_path):
                img_pil = torch.from_numpy(np.array(Image.open(img_path)))
                img_ljpeg = read_jpeg(img_path)
                self.assertTrue(img_ljpeg.equal(img_pil))

    def test_decode_jpeg(self):
        """``decode_jpeg`` equals the PIL reference and rejects malformed input."""
        for img_path in get_images(IMAGE_ROOT, ".jpg"):
            with self.subTest(img_path=img_path):
                img_pil = torch.from_numpy(np.array(Image.open(img_path)))
                # Load the raw file bytes as a 1-D uint8 tensor.
                size = os.path.getsize(img_path)
                img_ljpeg = decode_jpeg(torch.from_file(img_path, dtype=torch.uint8, size=size))
                self.assertTrue(img_ljpeg.equal(img_pil))

        # 2-D input.
        with self.assertRaisesRegex(ValueError, "Expected a non empty 1-dimensional tensor."):
            decode_jpeg(torch.empty((100, 1), dtype=torch.uint8))

        # Wrong dtype.
        with self.assertRaisesRegex(ValueError, "Expected a torch.uint8 tensor."):
            decode_jpeg(torch.empty((100, ), dtype=torch.float16))

        # ``(100)`` is the plain int 100, so this is a 1-D uint8 tensor with the
        # right shape and dtype but uninitialized contents.
        with self.assertRaises(RuntimeError):
            decode_jpeg(torch.empty((100), dtype=torch.uint8))

    def test_read_png(self):
        """``read_png(path)`` equals the PIL reference for every ``.png`` asset."""
        # Check across .png
        for img_path in get_images(IMAGE_DIR, ".png"):
            with self.subTest(img_path=img_path):
                img_pil = torch.from_numpy(np.array(Image.open(img_path)))
                img_lpng = read_png(img_path)
                self.assertTrue(img_lpng.equal(img_pil))

    def test_decode_png(self):
        """``decode_png`` equals the PIL reference and rejects malformed input."""
        for img_path in get_images(IMAGE_DIR, ".png"):
            with self.subTest(img_path=img_path):
                img_pil = torch.from_numpy(np.array(Image.open(img_path)))
                # Load the raw file bytes as a 1-D uint8 tensor.
                size = os.path.getsize(img_path)
                img_lpng = decode_png(torch.from_file(img_path, dtype=torch.uint8, size=size))
                self.assertTrue(img_lpng.equal(img_pil))

        # Zero-dimensional input.
        with self.assertRaises(ValueError):
            decode_png(torch.empty((), dtype=torch.uint8))

        # 300 random bytes, each 3 or 4.
        with self.assertRaises(RuntimeError):
            decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8))
# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment