Navigation Menu

Skip to content

Instantly share code, notes, and snippets.

@ksauzz
Forked from kuenishi/hdfs_image_dataset.py
Created November 28, 2018 10:17
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ksauzz/8b17448d1c6eed6704716b9dfeb28981 to your computer and use it in GitHub Desktop.
Save ksauzz/8b17448d1c6eed6704716b9dfeb28981 to your computer and use it in GitHub Desktop.
import getpass
import os
import threading
import time
from urllib.parse import urlparse
import zipfile
import numpy
try:
from PIL import Image
from pyarrow import hdfs
available = True
except ImportError as e:
available = False
_import_error = e
import six
import chainer
# Module-level HDFS connection handle; stays None until setup_hdfs() assigns it.
FS = None
# Lock taken around accesses to FS (and around the COUNTER increment).
FSLOCK = threading.Lock()
# Incremented every time setup_hdfs() establishes a connection.
COUNTER = 0
def _read_image_as_array(path, dtype):
    """Read the image stored at ``path`` through ``FS`` into a numpy array.

    Args:
        path: Path of the image file, handed to ``FS.open``.
        dtype: Data type passed to ``numpy.asarray`` for the result.

    Returns:
        The decoded image as a numpy array.

    Raises:
        AssertionError: If the module-level ``FS`` has not been set
            (i.e. ``setup_hdfs`` was never called).
    """
    # The whole read runs under FSLOCK, so accesses through the shared
    # connection handle are serialized.
    with FSLOCK:
        assert FS is not None
        with FS.open(path, 'rb') as fp:
            # Open the image *before* entering the try block: if
            # Image.open() itself fails there is nothing to close, and the
            # real error must propagate instead of being replaced by an
            # UnboundLocalError raised from the cleanup code.
            f = Image.open(fp)
            try:
                return numpy.asarray(f, dtype=dtype)
            finally:
                f.close()
def _read_image_inzip_as_array(zipfile, path, dtype):
    """Read one member of an open zip archive as a numpy array.

    Args:
        zipfile: An open archive object exposing ``open(name, mode)``
            (note: this parameter shadows the ``zipfile`` module inside
            this function).
        path: Name of the member inside the archive.
        dtype: Data type passed to ``numpy.asarray`` for the result.

    Returns:
        The decoded image as a numpy array.

    Raises:
        AssertionError: If ``zipfile`` is ``None``.
    """
    assert zipfile is not None
    with zipfile.open(path, 'r') as fp:
        # Open the image *before* entering the try block: if Image.open()
        # fails there is nothing to close, and the real error must
        # propagate instead of an UnboundLocalError from the cleanup code.
        f = Image.open(fp)
        try:
            return numpy.asarray(f, dtype=dtype)
        finally:
            f.close()
def setup_hdfs(host, port, user=None):
    """Connect to HDFS and store the connection in the module-level ``FS``.

    Every call opens a new connection via ``hdfs.connect``, replaces ``FS``
    with it and increments ``COUNTER``.

    Args:
        host: Host passed to ``hdfs.connect``.
        port: Port passed to ``hdfs.connect``.
        user: User to connect as; defaults to ``getpass.getuser()``.
    """
    global FS, COUNTER
    if user is None:
        user = getpass.getuser()
    with FSLOCK:
        FS = hdfs.connect(host, port, user=user)
        COUNTER += 1
        # Snapshot both values while still holding the lock so the log line
        # below describes *this* connection even if another thread
        # reconnects right after the lock is released.
        connection, count = FS, COUNTER
    if connection is not None:
        print('Connected to HDFS', host, port, 'as', user,
              'at process', os.getpid(), 'counter =', count)
class ImageDataset(chainer.datasets.ImageDataset):
    """Image dataset whose files are read through the module-level ``FS``.

    ``FS`` must already be set (see ``setup_hdfs``) when an instance is
    created; construction asserts on it.
    """

    def __init__(self, paths, root='.', dtype=numpy.float32):
        # Refuse to build the dataset before a connection has been set up.
        with FSLOCK:
            assert FS is not None
        super(ImageDataset, self).__init__(paths, root, dtype)

    def get_example(self, i):
        """Return the ``i``-th image with its last axis moved to the front."""
        relative = self._paths[i]
        pixels = _read_image_as_array(
            os.path.join(self._root, relative), self._dtype)
        if pixels.ndim != 2:
            return pixels.transpose(2, 0, 1)
        # Two-dimensional (greyscale) image: append a trailing axis first so
        # the transpose below works the same way as for colour images.
        return pixels[:, :, numpy.newaxis].transpose(2, 0, 1)
class LabeledImageDataset(chainer.datasets.LabeledImageDataset):
    """Labeled image dataset whose files are read through the module ``FS``.

    ``FS`` must already be set (see ``setup_hdfs``) when an instance is
    created; construction asserts on it.
    """

    def __init__(self, pairs, root, dtype=numpy.float32,
                 label_dtype=numpy.int32):
        # Refuse to build the dataset before a connection has been set up.
        with FSLOCK:
            assert FS is not None
        super(LabeledImageDataset, self).__init__(
            pairs, root, dtype, label_dtype)

    def get_example(self, i):
        """Return ``(image, label)`` for the ``i``-th pair.

        The image has its last axis moved to the front; the label is a
        numpy array of ``label_dtype``.
        """
        relative, raw_label = self._pairs[i]
        pixels = _read_image_as_array(
            os.path.join(self._root, relative), self._dtype)
        if pixels.ndim == 2:
            # Two-dimensional (greyscale) image: append a trailing axis.
            pixels = pixels[:, :, numpy.newaxis]
        return (pixels.transpose(2, 0, 1),
                numpy.array(raw_label, dtype=self._label_dtype))
class ZippedImageDataset(chainer.datasets.ImageDataset):
    """Image dataset stored as members of one zip archive at an hdfs:// URL.

    ``root`` is the ``hdfs://`` URL of the archive and each entry of
    ``paths`` is the name of a member inside it. The archive is opened
    lazily on the first ``get_example`` call and re-opened whenever the
    process id differs from the one that opened it.
    """

    def __init__(self, paths, root='.', dtype=numpy.float32):
        assert root.startswith('hdfs://')
        super(ZippedImageDataset, self).__init__(paths, root, dtype)
        # Keep the URL verbatim for opening the archive later.
        self._root = root
        self._pid = os.getpid()
        self._hdfsfile = None
        self._zipfile = None

    def __reduce__(self):
        # Rebuild from constructor arguments only, like the sibling zipped
        # datasets in this module: open handles are never pickled.
        return self.__class__, (self._paths, self._root, self._dtype)

    def get_example(self, i):
        """Return the ``i``-th image with its last axis moved to the front."""
        if self._pid != os.getpid() or self._zipfile is None:
            # Handles opened by another process (or none yet): connect and
            # open the archive in this process.
            self._url = urlparse(self._root)
            self._pid = os.getpid()
            setup_hdfs(self._url.hostname, self._url.port)
            with FSLOCK:
                # The archive is opened through the HDFS handle, the same
                # way ZippedLabeledImageDataset does it; handing the
                # hdfs:// URL to ZipFile directly would treat it as a local
                # file name.
                self._hdfsfile = FS.open(self._root, 'rb')
                self._zipfile = zipfile.ZipFile(self._hdfsfile, 'r')
        path = self._paths[i]
        image = _read_image_inzip_as_array(self._zipfile, path,
                                           self._dtype)
        if image.ndim == 2:
            # image is greyscale
            image = image[:, :, numpy.newaxis]
        return image.transpose(2, 0, 1)

    def finalize(self):
        """Close the archive and its underlying file, if they are open.

        Note that iterator does not finalize datasets, so use this dataset
        at your own risk!
        """
        # Reset the handles after closing so a later get_example() re-opens
        # the archive instead of touching a closed one.
        if self._zipfile:
            self._zipfile.close()
            self._zipfile = None
        if self._hdfsfile:
            self._hdfsfile.close()
            self._hdfsfile = None
class ZippedLabeledImageDataset(chainer.datasets.LabeledImageDataset):
    """Labeled image dataset stored in one zip archive at an hdfs:// URL.

    ``root`` is the ``hdfs://`` URL of the archive. Each pair is
    ``(path, label)``; ``path`` is joined onto the fixed ``'ILSVRC2012'``
    prefix to form the member name inside the archive. The archive is
    opened lazily and re-opened whenever the process id changes. The time
    spent reading each image is recorded and summarized by ``stats``.
    """

    def __init__(self, pairs, root, dtype=numpy.float32,
                 label_dtype=numpy.int32):
        assert root.startswith('hdfs://')
        super(ZippedLabeledImageDataset, self).__init__(pairs, root, dtype, label_dtype)
        self._pid = os.getpid()
        self._hdfsfile = None
        self._zipfile = None
        # Seconds spent in each image read, one entry per get_example call.
        self._timing = []

    def __reduce__(self):
        # Rebuild from constructor arguments only; open handles and
        # collected timings are not carried over.
        return self.__class__, (self._pairs, self._root, self._dtype, self._label_dtype)

    def get_example(self, i):
        """Return ``(image, label)`` for the ``i``-th pair.

        The image has its last axis moved to the front; the label is a
        numpy array of ``label_dtype``.
        """
        if self._pid != os.getpid() or self._zipfile is None or self._hdfsfile is None:
            # Handles opened by another process (or none yet): connect and
            # open the archive in this process.
            self._pid = os.getpid()
            self._url = urlparse(self._root)
            setup_hdfs(self._url.hostname, self._url.port)
            with FSLOCK:
                self._hdfsfile = FS.open(self._root, 'rb')
                self._zipfile = zipfile.ZipFile(self._hdfsfile, 'r')
        assert self._zipfile is not None
        path, int_label = self._pairs[i]
        path = os.path.join('ILSVRC2012', path)
        b = time.time()
        image = _read_image_inzip_as_array(self._zipfile, path,
                                           self._dtype)
        e = time.time()
        self._timing.append(e - b)
        if image.ndim == 2:
            # image is greyscale
            image = image[:, :, numpy.newaxis]
        label = numpy.array(int_label, dtype=self._label_dtype)
        return image.transpose(2, 0, 1), label

    def finalize(self):
        """Close the archive and its underlying file, if they are open.

        Note that iterator does not finalize datasets, so use this dataset
        at your own risk!
        """
        # Close the archive before the file it reads from, and reset both
        # handles so a later get_example() re-opens them instead of using a
        # closed archive.
        if self._zipfile:
            self._zipfile.close()
            self._zipfile = None
        if self._hdfsfile:
            self._hdfsfile.close()
            self._hdfsfile = None

    def stats(self):
        """Return ``(average read time in seconds, number of reads)``."""
        return numpy.average(self._timing), len(self._timing)
class ZippedLabeledImageDataset2(chainer.datasets.LabeledImageDataset):
    """Labeled image dataset stored in one zip archive opened by path.

    ``root`` must end with ``.zip`` and is handed directly to
    ``zipfile.ZipFile``. Each pair is ``(internal-path, label)``; the path
    is joined onto the fixed ``'ILSVRC2012'`` prefix to form the member
    name. The archive is opened lazily and re-opened whenever the process
    id changes. The time spent reading each image is recorded and
    summarized by ``stats``.
    """

    def __init__(self, pairs, root, dtype=numpy.float32,
                 label_dtype=numpy.int32):
        assert root.endswith('.zip')
        super(ZippedLabeledImageDataset2, self).__init__(pairs, root, dtype, label_dtype)
        self._pid = os.getpid()
        self._zipfile = None
        # Seconds spent in each image read, one entry per get_example call.
        self._timing = []

    def __reduce__(self):
        # Rebuild from constructor arguments only; the open archive and
        # collected timings are not carried over.
        return self.__class__, (self._pairs, self._root, self._dtype, self._label_dtype)

    def get_example(self, i):
        """Return ``(image, label)`` for the ``i``-th pair.

        The image has its last axis moved to the front; the label is a
        numpy array of ``label_dtype``.
        """
        if self._pid != os.getpid() or self._zipfile is None:
            # Archive opened by another process (or not yet): open it here.
            self._pid = os.getpid()
            self._zipfile = zipfile.ZipFile(self._root, 'r')
        assert self._zipfile is not None
        path, int_label = self._pairs[i]
        path = os.path.join('ILSVRC2012', path)
        b = time.time()
        image = _read_image_inzip_as_array(self._zipfile, path,
                                           self._dtype)
        e = time.time()
        self._timing.append(e - b)
        if image.ndim == 2:
            # image is greyscale
            image = image[:, :, numpy.newaxis]
        label = numpy.array(int_label, dtype=self._label_dtype)
        return image.transpose(2, 0, 1), label

    def finalize(self):
        """Close the archive if it is open.

        Note that iterator does not finalize datasets, so use this dataset
        at your own risk!
        """
        # Reset the handle after closing so a later get_example() re-opens
        # the archive instead of using a closed one.
        if self._zipfile:
            self._zipfile.close()
            self._zipfile = None

    def stats(self):
        """Return ``(average read time in seconds, number of reads)``."""
        return numpy.average(self._timing), len(self._timing)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment