Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Extract the labeled regions from a CVAT annotation XML file. / CVATのアノテーション情報から画像を取り出す
# Install dependencies: pip install numpy opencv-python lxml tqdm joblib
import argparse
import pathlib
import cv2
import numpy as np
from joblib import Parallel, delayed
from lxml import etree
def _parse_anno_file(root: etree.ElementTree, image_name):
    """Collect the CVAT annotations recorded for one image.

    Every ``<image>`` element below ``root`` whose ``name`` attribute equals
    ``image_name`` yields one dict holding that element's XML attributes plus
    a ``'shapes'`` list. Each ``<polygon>`` and ``<box>`` found inside the
    image becomes a dict of its XML attributes with an added ``'type'`` key.
    Boxes additionally receive a ``'points'`` string listing their four
    corners, so later code can treat them like polygons.

    Args:
        root: Element whose descendants are searched.
        image_name: Value of the ``name`` attribute to look for.

    Returns:
        A list with one dict per matching ``<image>`` element. The shapes of
        each image are sorted by ascending integer ``z_order`` (a missing
        ``z_order`` counts as 0). The list is empty when nothing matches.

    Raises:
        KeyError: If a ``<box>`` lacks ``xtl``, ``ytl``, ``xbr`` or ``ybr``.
        ValueError: If a ``z_order`` attribute is not an integer literal.
    """
    annotations = []
    # Compare the ``name`` attribute in Python instead of formatting the file
    # name into an XPath predicate: a name containing a quote character
    # would otherwise produce a malformed (or wrongly matching) expression.
    for image_tag in root.iterfind('.//image'):
        if image_tag.get('name') != image_name:
            continue

        image = dict(image_tag.items())
        shapes = []

        for poly_tag in image_tag.iter('polygon'):
            polygon = {'type': 'polygon'}
            polygon.update(poly_tag.items())
            shapes.append(polygon)

        for box_tag in image_tag.iter('box'):
            box = {'type': 'box'}
            box.update(box_tag.items())
            # Corners in drawing order: top-left, top-right, bottom-right,
            # bottom-left.
            box['points'] = "{0},{1};{2},{1};{2},{3};{0},{3}".format(
                box['xtl'], box['ytl'], box['xbr'], box['ybr'])
            shapes.append(box)

        # Lower z_order first, so shapes are handed on back-to-front.
        shapes.sort(key=lambda shape: int(shape.get('z_order', 0)))
        image['shapes'] = shapes
        annotations.append(image)
    return annotations
def _create_mask_file(width, height, background, shapes, scale_factor=1.0):
    """Paint shapes in white onto a new mask initialised from ``background``.

    Args:
        width: Mask width in pixels.
        height: Mask height in pixels.
        background: Initial fill for the mask; a scalar or any value that
            broadcasts to ``(height, width, 3)``, such as an earlier mask.
        shapes: Iterable of dicts whose ``'points'`` entry is a string of the
            form ``"x1,y1;x2,y2;..."``.
        scale_factor: Multiplier applied to every coordinate before drawing.

    Returns:
        A ``uint8`` array of shape ``(height, width, 3)`` in which every
        shape has been outlined and filled with ``(255, 255, 255)``.

    Raises:
        KeyError: If a shape has no ``'points'`` entry.
        ValueError: If a coordinate cannot be parsed as a float.
    """
    mask = np.full((height, width, 3), background, dtype=np.uint8)
    white = (255, 255, 255)
    for shape in shapes:
        # Only the first two values of each "x,y" pair are used.
        coords = np.array([
            [float(value) for value in pair.split(',')[:2]]
            for pair in shape['points'].split(';')
        ])
        # Scale the sub-pixel coordinates first and truncate once afterwards.
        # Truncating before scaling throws away the fractional part that the
        # scale factor should have amplified. With the default factor of 1.0
        # the result is the same as truncating up front.
        points = (coords * scale_factor).astype(np.int32)
        mask = cv2.drawContours(mask, [points], -1, color=white, thickness=0)
        mask = cv2.fillPoly(mask, [points], color=white)
    return mask
def _save_mask_image(cvat_xml: str, image: pathlib.Path, output: pathlib.Path):
    """Write ``image`` as a PNG whose alpha channel is its annotation mask.

    The shapes annotated for ``image.name`` in ``cvat_xml`` are rasterised
    into a mask, and that mask becomes the alpha channel of a copy of the
    image saved as ``output / (image.stem + '.png')``.

    Nothing is written when the file suffix is not ``.png`` or ``.jpg``,
    when OpenCV cannot read the file, or when the resulting mask is empty.

    Args:
        cvat_xml: Path of the CVAT annotation XML file.
        image: Path of the source image.
        output: Directory that receives the PNG.

    Returns:
        None.
    """
    # Skip non target file type. Checked before anything else so the XML is
    # not parsed for files that are going to be ignored anyway.
    if image.suffix not in ('.png', '.jpg'):
        return

    root = etree.parse(cvat_xml).getroot()
    annotations = _parse_anno_file(root, image.name)

    base_image = cv2.imread(str(image), cv2.IMREAD_UNCHANGED)
    # cv2.imread signals an unreadable file by returning None, not by raising.
    if base_image is None:
        return
    # shape[:2] also works for single-channel images, whose shape is 2-D.
    height, width = base_image.shape[:2]

    background = np.zeros((height, width, 3), np.uint8)
    for annotation in annotations:
        background = _create_mask_file(width, height,
                                       background, annotation['shapes'])
    # Skip no mask image
    if not background.any():
        return

    # IMREAD_UNCHANGED keeps the file's own channel layout, so grayscale and
    # already-transparent images have to be brought to four channels too.
    if base_image.ndim == 2:
        png_image = cv2.cvtColor(base_image, cv2.COLOR_GRAY2BGRA)
    elif base_image.shape[2] == 4:
        png_image = base_image.copy()
    else:
        # Adds an alpha channel without reordering the colour channels.
        png_image = cv2.cvtColor(base_image, cv2.COLOR_RGB2RGBA)
    # NOTE(review): the mask holds 0/255, which is fully opaque only for
    # 8-bit images -- confirm whether 16-bit inputs need to be supported.
    png_image[:, :, 3] = background[:, :, 0]

    ypath = output / (image.stem + '.png')
    cv2.imwrite(str(ypath), png_image)
def create_dataset(image_dir: str, cvat_xml: str, output_dir: str,
                   n_jobs: int = 4):
    """Create a masked PNG for every annotated image in ``image_dir``.

    Args:
        image_dir: Directory containing the source images.
        cvat_xml: Path of the CVAT annotation XML file.
        output_dir: Directory for the generated PNGs; created (including
            missing parents) when it does not exist yet.
        n_jobs: Number of parallel joblib workers. Defaults to 4.

    Returns:
        None.
    """
    output = pathlib.Path(output_dir)
    output.mkdir(exist_ok=True, parents=True)
    # Only regular files are dispatched; sub-directories cannot be images.
    images = [path for path in pathlib.Path(image_dir).iterdir()
              if path.is_file()]
    Parallel(n_jobs=n_jobs)(
        delayed(_save_mask_image)(cvat_xml, image, output)
        for image in images)
if __name__ == '__main__':
    # Command-line entry point. All three options are mandatory: leaving one
    # out would otherwise pass None on to pathlib / the XML parser and fail
    # there with an unrelated-looking error.
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_dir', type=str, required=True,
                        help='images directory')
    parser.add_argument('--cvat_xml', type=str, required=True,
                        help='cvat annotation xml')
    parser.add_argument('--output_dir', type=str, required=True,
                        help='path for mask images')
    args = parser.parse_args()
    create_dataset(args.image_dir, args.cvat_xml, args.output_dir)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment