Skip to content

Instantly share code, notes, and snippets.

@tok41
Last active April 26, 2019 05:05
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tok41/2903135861e738c182ef4d60e0fcff1b to your computer and use it in GitHub Desktop.
Save tok41/2903135861e738c182ef4d60e0fcff1b to your computer and use it in GitHub Desktop.
# Read annotation data in PASCAL VOC format (XML files) and
# create a dataset on ABEJA Platform from it.
import os
from dotenv import load_dotenv
import numpy as np
import glob
import xmltodict
import abeja
from abeja.datalake import Client as DatalakeClient
from abeja.datasets import APIClient
from abejacli.config import (
ABEJA_PLATFORM_USER_ID,
ABEJA_PLATFORM_TOKEN
)
# Credentials passed to the ABEJA dataset and datalake clients below;
# both values are taken from the abejacli configuration module.
credential = dict(
    user_id=ABEJA_PLATFORM_USER_ID,
    personal_access_token=ABEJA_PLATFORM_TOKEN,
)
def getClasses(classes_file):
    """Read the class definition file into an ordered list of class names.

    Args:
        classes_file: path to a text file holding one class name per line.

    Returns:
        list of class names in file order. A name's position in the list is
        used as its label id by the rest of this script. Surrounding
        whitespace is stripped and blank lines are skipped, so a trailing
        empty line does not create an empty class.
    """
    with open(classes_file) as fd:
        stripped = (line.strip() for line in fd)
        return [name for name in stripped if name]
def getBBoxData(anno_file, classes):
    """Parse one PASCAL VOC annotation file into labelled bounding boxes.

    Args:
        anno_file: file name of the annotation (XML).
        classes: list of class names; a name's index is its label id.

    Returns:
        dict of the form
            {'image': filename,
             'objs': [{'label': ..., 'label_id': ..., 'bbox': ...}, ...]}
        where 'bbox' is the parsed <bndbox> mapping (xmin, ymin, xmax, ymax),
        with values left exactly as xmltodict produced them.

    Raises:
        ValueError: if an object's name is not present in `classes`.
    """
    with open(anno_file) as fd:
        ann_data = xmltodict.parse(fd.read())['annotation']

    # xmltodict returns a single mapping (not a list) when the annotation has
    # exactly one <object>, and omits the key when there is none. Normalize
    # both cases to a list so the loop below always sees one mapping per object.
    raw_objects = ann_data.get('object') or []
    if not isinstance(raw_objects, list):
        raw_objects = [raw_objects]

    objs = []
    for obj in raw_objects:
        label = obj['name']
        objs.append({'label': label,
                     'label_id': classes.index(label),
                     'bbox': obj['bndbox']})
    return {'image': ann_data['filename'],
            'objs': objs}
if __name__ == '__main__':
    # Script entry point: (re)create a detection dataset on ABEJA Platform and
    # register one dataset item per datalake file that has a matching
    # PASCAL VOC annotation XML in the given directory.
    import argparse

    parser = argparse.ArgumentParser()
    # Every option is dereferenced unconditionally below, so all are required.
    parser.add_argument('--channel', required=True,
                        help='DATA LAKE CHANNEL ID')
    parser.add_argument('--dataset', required=True, help='Dataset Name')
    parser.add_argument('--annotation', required=True,
                        help='Directory Name of Annotation files')
    parser.add_argument('--classes', required=True,
                        help='class definition file')
    args = parser.parse_args()

    dotenv_path = os.path.join(os.path.dirname(__file__), '.env')
    load_dotenv(dotenv_path)

    # ID
    ORG_ID = os.environ.get("ORGANIZATION_ID")
    if not ORG_ID:
        # Fail early instead of sending organization_id=None to the API.
        parser.error('ORGANIZATION_ID is not set (environment or .env file)')
    channel_id = args.channel

    classes = getClasses(args.classes)
    labels = [{'label': cat, 'label_id': label_id}
              for label_id, cat in enumerate(classes)]
    datasetname = args.dataset
    category = {
        'labels': labels,
        'category_id': 0,
        'name': datasetname
    }
    props = {'categories': [category]}

    # Create DataSet
    dataset_client = APIClient(credential)
    dataset_list = dataset_client.list_datasets(organization_id=ORG_ID)
    for ds in dataset_list:
        # Exact match only: a substring test would also delete unrelated
        # datasets whose names merely contain the requested name.
        if ds['name'] == datasetname:
            dataset_client.delete_dataset(ORG_ID, ds['dataset_id'])
            break
    dataset = dataset_client.create_dataset(
        ORG_ID, datasetname, 'detection', props)
    dataset_id = dataset['dataset_id']

    # Get Data List from Datalake
    client = DatalakeClient(
        organization_id=ORG_ID, credential=credential)
    channel = client.get_channel(channel_id)

    anno_dir = args.annotation
    # Set for constant-time membership checks while walking the channel.
    anno_files = set(glob.glob(os.path.join(anno_dir, '*.xml')))

    for f in channel.list_files():
        f_info = f.get_file_info()
        filename = f_info['metadata']['x-abeja-meta-filename']
        f_xml = os.path.join(anno_dir, '{}.xml'.format(
            os.path.splitext(filename)[0]))
        if f_xml not in anno_files:
            # No annotation for this datalake file; nothing to register.
            continue

        file_id = f_info['file_id']
        content_type = f_info['content_type']
        dict_anno = getBBoxData(anno_file=f_xml, classes=classes)
        info = [{
            'category_id': 0,
            'label': obj['label'],
            'label_id': int(obj['label_id']),
            'rect': {
                'xmin': obj['bbox']['xmin'],
                'ymin': obj['bbox']['ymin'],
                'xmax': obj['bbox']['xmax'],
                'ymax': obj['bbox']['ymax'],
            }} for obj in dict_anno['objs']]

        # Create DatasetItem
        data_uri = 'datalake://{}/{}'.format(channel_id, file_id)
        source_data = [{'data_uri': data_uri, 'data_type': content_type}]
        attributes = {'detection': info}
        dataset_client.create_dataset_item(
            ORG_ID, dataset_id, source_data, attributes)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment