Last active
April 26, 2019 05:05
-
-
Save tok41/2903135861e738c182ef4d60e0fcff1b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Import annotation data in PASCAL VOC format (XML files)
# and create a dataset on the ABEJA Platform.
import os | |
from dotenv import load_dotenv | |
import numpy as np | |
import glob | |
import xmltodict | |
import abeja | |
from abeja.datalake import Client as DatalakeClient | |
from abeja.datasets import APIClient | |
from abejacli.config import ( | |
ABEJA_PLATFORM_USER_ID, | |
ABEJA_PLATFORM_TOKEN | |
) | |
# Credential mapping handed to the ABEJA SDK clients created further down;
# both values are imported from abejacli.config.
credential = dict(
    user_id=ABEJA_PLATFORM_USER_ID,
    personal_access_token=ABEJA_PLATFORM_TOKEN,
)
def getClasses(classes_file):
    """Read the class definition file and return the class names in order.

    The file holds one class name per line. Surrounding whitespace is
    stripped and blank lines are skipped, so a stray empty line or trailing
    spaces never produce an empty or mismatching class name.

    Args:
        classes_file: path of the class definition file.

    Returns:
        list[str]: the class names in file order. Callers use a name's
        position in this list as its label id.
    """
    with open(classes_file) as fd:
        # Iterate the file object directly instead of materialising
        # readlines(); strip() also removes the line terminator.
        stripped = (line.strip() for line in fd)
        return [name for name in stripped if name]
def getBBoxData(anno_file, classes):
    """Parse one PASCAL VOC annotation file into label / bounding-box records.

    Args:
        anno_file: path of the annotation XML file.
        classes: list of class names; an object's label id is the index of
            its name in this list.

    Returns:
        dict: ``{'image': <content of the filename element>, 'objs': [...]}``
        where every entry of ``objs`` is
        ``{'label': str, 'label_id': int, 'bbox': <parsed bndbox element>}``.
        ``objs`` is empty when the annotation contains no object. ``bbox``
        is passed through unchanged from the parser.
        NOTE(review): xmltodict normally yields text values, so the
        coordinates are presumably strings -- confirm before doing
        arithmetic on them.

    Raises:
        ValueError: if an object's name is not present in ``classes``.
    """
    with open(anno_file) as fd:
        ann_data = xmltodict.parse(fd.read())['annotation']

    # xmltodict only produces a list for an element that occurs more than
    # once: a single <object> comes back as one mapping and a missing one is
    # simply absent. Normalise all three cases to a list before iterating.
    raw_objects = ann_data.get('object') or []
    if not isinstance(raw_objects, list):
        raw_objects = [raw_objects]

    objs = []
    for obj in raw_objects:
        label = obj['name']
        objs.append({'label': label,
                     'label_id': classes.index(label),
                     'bbox': obj['bndbox']})
    return {'image': ann_data['filename'],
            'objs': objs}
def main():
    """Create an ABEJA Platform detection dataset from PASCAL VOC annotations.

    Steps:
      1. Parse the command line and load ``.env`` from the script directory.
      2. Build the category/label definition from the class file.
      3. Delete the existing dataset with the same name (if any) and create
         a fresh one of type ``'detection'``.
      4. For every file in the Data Lake channel that has a matching
         ``<basename>.xml`` in the annotation directory, register a dataset
         item carrying the bounding boxes found in that XML file.

    Files in the channel without a matching annotation file are skipped.
    """
    import argparse

    parser = argparse.ArgumentParser()
    # Every option is used unconditionally below, so let argparse reject a
    # missing one up front instead of failing later with an obscure error.
    parser.add_argument('--channel', required=True,
                        help='DATA LAKE CHANNEL ID')
    parser.add_argument('--dataset', required=True, help='Dataset Name')
    parser.add_argument('--annotation', required=True,
                        help='Directory Name of Annotation files')
    parser.add_argument('--classes', required=True,
                        help='class definition file')
    args = parser.parse_args()

    dotenv_path = os.path.join(os.path.dirname(__file__), '.env')
    load_dotenv(dotenv_path)

    # ID
    org_id = os.environ.get("ORGANIZATION_ID")
    if not org_id:
        parser.error('ORGANIZATION_ID is not set '
                     '(define it in the environment or in the .env file)')
    channel_id = args.channel

    classes = getClasses(args.classes)
    labels = [{'label': cat, 'label_id': idx}
              for idx, cat in enumerate(classes)]
    datasetname = args.dataset
    category = {
        'labels': labels,
        'category_id': 0,
        'name': datasetname
    }
    props = {'categories': [category]}

    # Create DataSet
    dataset_client = APIClient(credential)
    for ds in dataset_client.list_datasets(organization_id=org_id):
        # Exact match only: a substring test could delete an unrelated
        # dataset whose name merely contains the requested one.
        if ds['name'] == datasetname:
            dataset_client.delete_dataset(org_id, ds['dataset_id'])
            break
    dataset = dataset_client.create_dataset(
        org_id, datasetname, 'detection', props)
    dataset_id = dataset['dataset_id']

    # Get Data List from Datalake
    client = DatalakeClient(
        organization_id=org_id, credential=credential)
    channel = client.get_channel(channel_id)

    anno_dir = args.annotation
    # A set gives constant-time membership tests inside the per-file loop.
    anno_files = set(glob.glob(os.path.join(anno_dir, '*.xml')))

    for f in channel.list_files():
        f_info = f.get_file_info()
        filename = f_info['metadata']['x-abeja-meta-filename']
        file_id = f_info['file_id']
        content_type = f_info['content_type']

        f_xml = os.path.join(anno_dir, '{}.xml'.format(
            os.path.splitext(filename)[0]))
        if f_xml not in anno_files:
            # No annotation for this file: nothing to register.
            continue

        dict_anno = getBBoxData(anno_file=f_xml, classes=classes)
        # NOTE(review): the rect coordinates are passed through exactly as
        # parsed from the XML; confirm the dataset API accepts that form.
        info = [{
            'category_id': 0,
            'label': obj['label'],
            'label_id': int(obj['label_id']),
            'rect': {
                'xmin': obj['bbox']['xmin'],
                'ymin': obj['bbox']['ymin'],
                'xmax': obj['bbox']['xmax'],
                'ymax': obj['bbox']['ymax'],
            }} for obj in dict_anno['objs']]

        # Create DatasetItem
        data_uri = 'datalake://{}/{}'.format(channel_id, file_id)
        source_data = [{'data_uri': data_uri, 'data_type': content_type}]
        attributes = {'detection': info}
        dataset_client.create_dataset_item(
            org_id, dataset_id, source_data, attributes)


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment