Skip to content

Instantly share code, notes, and snippets.

@N-McA
Last active April 27, 2018 21:22
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save N-McA/54dc2e2ee9c6ad2c4d2a98d81345013d to your computer and use it in GitHub Desktop.
Save N-McA/54dc2e2ee9c6ad2c4d2a98d81345013d to your computer and use it in GitHub Desktop.
Trade memory for time when holding big stack of jpgs
'''
Compatible with Keras, and faster than reading from files (no stats).
It's only designed to work if all your images are vaguely similar sizes
when encoded as JPGs, so if you have:
    - white noise or other hard-to-encode stuff
    - radically varying image sizes
    - ... (probably other failure modes)
then this is foolhardy.
On the plus side, you can do:
    imgs = MultiJPG.from_file(...)
    model.fit(imgs, labels)
and it'll work.
It has weird behavior, to be honest - if you call
    imgs[:10]
you get a new MultiJPG, but if you call
    imgs[np.arange(10)]
you get *the actual image data*.
The end result of that is that you can pass it in with a validation proportion in Keras,
but to be honest I'm not really sure if it's sane. I just wrote this because it seemed like a neat idea.
'''
from PIL import Image as pil_image
from io import BytesIO
import numpy as np
from keras.preprocessing.image import load_img, img_to_array, Iterator
def _img_to_jpeg_bytes(img: np.ndarray) -> bytes:
    """Encode a float image as JPEG bytes.

    Args:
        img: Array of pixel intensities scaled so that 1.0 maps to 255.
            If the last axis has length 1 it is dropped before encoding
            (e.g. ``(H, W, 1)`` becomes ``(H, W)``).

    Returns:
        The JPEG-encoded image (quality 90, ``subsampling=0``).
    """
    if img.shape[-1] == 1:
        img = img.reshape(img.shape[:-1])
    # Round and clip before the uint8 cast. A bare cast truncates, so a
    # value that was decoded as k/255 in floating point can come back as
    # k-1, and it does not saturate, so values slightly outside [0, 1]
    # would wrap around instead of pinning to black/white.
    pixels = np.clip(np.rint(img * 255), 0, 255).astype(np.uint8)
    with BytesIO() as f:
        pil_image.fromarray(pixels).save(
            f, format='JPEG', quality=90, subsampling=0)
        return f.getvalue()
def _img_from_jpeg_bytes(bts, n_channels=3) -> np.ndarray:
    """Decode JPEG bytes into an image array divided by 255.

    Args:
        bts: Bytes-like object holding the encoded image; it is wrapped in
            a ``BytesIO`` and handed to ``load_img``.
        n_channels: When 1, only the first channel of the decoded image is
            kept (as a trailing axis of length 1). Any other value returns
            the decoded array unchanged.

    Returns:
        The decoded image as produced by ``img_to_array``, divided by 255.
    """
    with BytesIO(bts) as stream:
        decoded = img_to_array(load_img(stream))
    scaled = decoded / 255
    return scaled[:, :, :1] if n_channels == 1 else scaled
class IncrementalMultiJPG:
    """Collects images one at a time and turns them into a ``MultiJPG``.

    Each appended image is JPEG-encoded immediately, so only the encoded
    byte strings are held until ``finalize`` is called.
    """

    def __init__(self):
        # Encoded images, in the order they were appended.
        self.byte_strings = []

    def append(self, img):
        """Encode ``img`` with ``_img_to_jpeg_bytes`` and store the bytes."""
        encoded = _img_to_jpeg_bytes(img)
        self.byte_strings.append(encoded)

    def finalize(self):
        """Return a ``MultiJPG`` built from everything appended so far."""
        collected = self.byte_strings
        return MultiJPG.from_byte_strings(collected)
class MultiJPG:
    """A stack of JPEG-encoded images that decodes on access.

    The backing store is a 2-D ``uint8`` array with one encoded image per
    row. When built via ``from_byte_strings`` every row is zero-padded on
    the right to the length of the longest encoded image.

    Indexing behaviour:
        * integer (Python ``int`` or NumPy integer) -> one decoded image
        * ``slice`` -> a new instance over the selected rows
        * ``list`` / ``np.ndarray`` of indices -> an array of decoded images

    Instances are created through ``from_imgs``, ``from_file`` or
    ``from_byte_strings`` (or ``IncrementalMultiJPG``), not directly.
    """

    @classmethod
    def from_imgs(cls, imgs):
        """Encode every image in ``imgs`` and build an instance from them."""
        return cls.from_byte_strings([_img_to_jpeg_bytes(img) for img in imgs])

    @classmethod
    def from_file(cls, fname):
        """Open an array written by ``save`` as a read-only memory map."""
        return cls(_bytes=np.load(fname, mmap_mode='r'))

    @classmethod
    def from_byte_strings(cls, byte_strings):
        """Build an instance from an iterable of encoded images.

        Raises:
            ValueError: If ``byte_strings`` is empty.
        """
        return cls(_bytes=cls._pack_byte_strings(byte_strings))

    @staticmethod
    def _pack_byte_strings(byte_strings):
        """Zero-pad byte strings to equal length and stack them as uint8 rows."""
        # Materialise first: the input is traversed twice (max length, then
        # packing), which would silently drop everything for a generator.
        byte_strings = list(byte_strings)
        if not byte_strings:
            raise ValueError(
                'Cannot build a MultiJPG from an empty sequence of byte strings')
        block_size = max(len(bs) for bs in byte_strings)
        buffer = b''.join(
            bytes(bs).ljust(block_size, b'\0') for bs in byte_strings)
        flat = np.frombuffer(buffer, dtype=np.uint8)
        return flat.reshape([len(byte_strings), block_size])

    def __init__(self, *, _bytes=None):
        """Wrap an existing 2-D array of encoded images.

        Raises:
            ValueError: If ``_bytes`` is not supplied.
        """
        if _bytes is None:
            raise ValueError((
                'Use a constructor from:\n'
                'from_imgs(imgs)\n'
                'from_file(fname)\n'
                'Or use IncrementalMultiJPG to build incrementally'
            ))
        self._bytes = _bytes
        # The per-image shape is taken from decoding the first image, so
        # construction requires at least one row.
        self.shape = (len(self), *self[0].shape)
        self.ndim = 4

    def __getitem__(self, s):
        """Decode one image, slice the stack, or decode a batch of indices.

        Raises:
            ValueError: If ``s`` is an index list/array whose length equals
                ``len(self)``, or if ``s`` is of an unsupported type.
        """
        # np.integer covers scalars such as np.int64, which are not
        # instances of the built-in int.
        if isinstance(s, (int, np.integer)):
            return _img_from_jpeg_bytes(self._bytes[s])
        if isinstance(s, slice):
            return type(self)(_bytes=self._bytes[s])
        if isinstance(s, (np.ndarray, list)):
            if len(s) == len(self):
                raise ValueError('Indexed MultiJPG with len(index) == len(self) == {}'.format(len(self)))
            # np.zeros defaults to float64; decoded images are cast on
            # assignment into the batch.
            result = np.zeros([len(s), *self.shape[1:]])
            for i, idx in enumerate(s):
                result[i] = _img_from_jpeg_bytes(self._bytes[idx])
            return result
        raise ValueError('Cannot index MultiJPG with {} {}'.format(s.__class__, s))

    def __len__(self):
        """Return the number of stored images (rows of the backing array)."""
        return len(self._bytes)

    def save(self, fname):
        """Write the backing array with ``np.save`` (pickling disabled).

        Note that ``np.save`` appends ``.npy`` to string filenames that lack
        it, so pass the resulting name to ``from_file``.
        """
        np.save(fname, self._bytes, allow_pickle=False)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment