Skip to content

Instantly share code, notes, and snippets.

@tok41
Created April 26, 2019 06:04
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tok41/8315c82e1391f02be36118386acc718b to your computer and use it in GitHub Desktop.
Save tok41/8315c82e1391f02be36118386acc718b to your computer and use it in GitHub Desktop.
"""
dataset utils
"""
import io
import numpy as np
from PIL import Image
from abeja.datasets import Client
from chainercv.chainer_experimental.datasets.sliceable import GetterDataset
def load_dataset_from_api(dataset_id, organization_id=None, credential=None):
    """Fetch the items of an ABEJA Platform dataset.

    Args:
        dataset_id: ID of the dataset to load.
        organization_id: Organization that owns the dataset. When ``None``,
            ``Client()`` is constructed with no arguments (presumably picking
            up organization/credentials from the environment -- confirm), and
            ``credential`` is ignored.
        credential: Credential handed to ``Client`` together with
            ``organization_id``; only used when ``organization_id`` is given.

    Returns:
        The result of ``dataset.dataset_items.list(prefetch=True)``.
    """
    if organization_id is not None:
        api_client = Client(organization_id=organization_id, credential=credential)
    else:
        api_client = Client()
    items = api_client.get_dataset(dataset_id).dataset_items
    return items.list(prefetch=True)
def load_classes(classes_file):
    """Read class names from a text file, one name per line.

    Args:
        classes_file: Path of the file to read (opened in text mode with the
            platform default encoding).

    Returns:
        list of str: every line of the file, in order, with its trailing
        newline removed. Blank lines are kept as empty strings.
    """
    with open(classes_file) as handle:
        # Text mode translates any newline style to '\n', so each line has at
        # most one '\n', and only at its end.
        return [row.rstrip('\n') for row in handle]
class DetectionDatasetFromAPI(GetterDataset):
    """Object-detection dataset backed by ABEJA Platform dataset items.

    Adapted from
    https://github.com/abeja-inc/abeja-platform-samples/blob/master/chainer/ssd/dataset.py

    Registers the getters ``img`` and ``(bbox, label, difficult)`` on a
    chainercv ``GetterDataset``. By default only ``('img', 'bbox', 'label')``
    are returned per example.

    Args:
        dataset_list: Indexable, sized sequence of dataset items. Each item
            must provide ``source_data[0].get_content()`` (encoded image
            bytes) and ``attributes['detection']`` (a list of dicts with a
            ``'rect'`` dict holding ``xmin/ymin/xmax/ymax`` and a
            ``'label_id'``).
        use_difficult: Stored for interface parity with chainercv datasets.
            The annotations read here carry no difficulty flag, so it
            currently has no effect.
        return_difficult: If True, also return the ``difficult`` array.
    """

    def __init__(self, dataset_list, use_difficult=False, return_difficult=False):
        super(DetectionDatasetFromAPI, self).__init__()
        self.dataset_list = dataset_list
        self.use_difficult = use_difficult

        self.add_getter('img', self._get_image)
        self.add_getter(('bbox', 'label', 'difficult'), self._get_annotations)

        if not return_difficult:
            # Hide the 'difficult' getter from the default example tuple.
            self.keys = ('img', 'bbox', 'label')

    def __len__(self):
        return len(self.dataset_list)

    def read_image_as_array(self, file_obj):
        """Decode an image into a CHW float32 RGB array.

        Args:
            file_obj: Path or binary file-like object readable by PIL.

        Returns:
            numpy.ndarray of shape ``(3, H, W)``, dtype float32, values 0-255.
        """
        pil_img = Image.open(file_obj)
        try:
            # Normalize to 3-channel RGB so grayscale, palette, RGBA and CMYK
            # inputs all produce an (H, W, 3) array for the transpose below.
            img = np.asarray(pil_img.convert('RGB'), dtype=np.float32)
        finally:
            # Keep a separate name for the PIL image so that it, rather than
            # the ndarray, is what gets closed.
            if hasattr(pil_img, 'close'):
                pil_img.close()
        # HWC -> CHW
        return img.transpose((2, 0, 1))

    def _get_image(self, i):
        """Return the decoded image of example ``i`` as a (3, H, W) array."""
        item = self.dataset_list[i]
        # Assumes the first source_data entry holds the encoded image bytes.
        content = item.source_data[0].get_content()
        return self.read_image_as_array(io.BytesIO(content))

    def _get_annotations(self, i):
        """Return ``(bbox, label, difficult)`` for example ``i``.

        Returns:
            bbox: float32 array of shape ``(R, 4)`` in
                ``(ymin, xmin, ymax, xmax)`` order.
            label: int32 array of shape ``(R,)`` built from ``label_id``.
            difficult: bool array of shape ``(R,)``, all False, because the
                annotations carry no difficulty flag.
            ``R`` is the number of annotations and may be 0.
        """
        annotations = self.dataset_list[i].attributes['detection']
        bbox = []
        label = []
        for annotation in annotations:
            rect = annotation['rect']
            bbox.append((rect['ymin'], rect['xmin'], rect['ymax'], rect['xmax']))
            label.append(annotation['label_id'])
        # np.asarray + reshape (instead of np.stack) also works when R == 0.
        bbox = np.asarray(bbox).astype(np.float32).reshape((-1, 4))
        label = np.asarray(label).astype(np.int32)
        # Builtin bool: the np.bool alias was removed in NumPy 1.24. One
        # entry per box keeps the three arrays aligned.
        difficult = np.zeros(len(label), dtype=bool)
        return bbox, label, difficult
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment