Extract labeled regions from a CVAT annotation XML as alpha-masked PNG images. / CVATのアノテーション情報から画像を取り出す
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Requirements: pip install numpy opencv-python lxml tqdm joblib
import argparse | |
import pathlib | |
import cv2 | |
import numpy as np | |
from joblib import Parallel, delayed | |
from lxml import etree | |
def _parse_anno_file(root: etree.ElementTree, image_name: str) -> list:
    """Collect the annotated shapes of one image from a parsed CVAT XML tree.

    Args:
        root: Root element of the CVAT annotation XML.
        image_name: Value of the ``name`` attribute of the ``<image>``
            elements to collect.

    Returns:
        One dict per matching ``<image>`` element. Each dict holds the
        element's XML attributes (as strings) plus a ``'shapes'`` list.
        Every shape is a dict of the ``<polygon>``/``<box>`` attributes with
        an added ``'type'`` key; boxes additionally get a ``'points'``
        string listing their four corners as ``"x,y;x,y;x,y;x,y"``.
        Shapes are ordered by ascending ``z_order`` (a missing ``z_order``
        counts as 0).
    """
    annotations = []
    # Compare the name in Python instead of embedding it in an XPath
    # predicate: a file name containing a quote character would otherwise
    # produce an invalid (or wrongly matching) expression.
    for image_tag in root.iterfind('.//image'):
        if image_tag.get('name') != image_name:
            continue

        image = dict(image_tag.items())
        shapes = []

        for poly_tag in image_tag.iter('polygon'):
            # XML attributes win over the default 'type' entry.
            shapes.append({'type': 'polygon', **dict(poly_tag.items())})

        for box_tag in image_tag.iter('box'):
            box = {'type': 'box', **dict(box_tag.items())}
            # Express the rectangle as a closed corner list so boxes can be
            # handled exactly like polygons downstream.
            box['points'] = "{0},{1};{2},{1};{2},{3};{0},{3}".format(
                box['xtl'], box['ytl'], box['xbr'], box['ybr'])
            shapes.append(box)

        # list.sort is stable: shapes sharing a z_order keep document order
        # (polygons before boxes).
        shapes.sort(key=lambda shape: int(shape.get('z_order', 0)))
        image['shapes'] = shapes
        annotations.append(image)
    return annotations
def _create_mask_file(width, height, background, shapes, scale_factor=1.0):
    """Paint the given shapes in white onto a fresh copy of ``background``.

    Args:
        width: Mask width in pixels.
        height: Mask height in pixels.
        background: Fill value (or array) used to initialise the
            ``(height, width, 3)`` mask.
        shapes: Iterable of dicts whose ``'points'`` entry is a string of
            ``"x,y;x,y;..."`` vertex coordinates.
        scale_factor: Multiplier applied to the vertex coordinates after
            they have been truncated to integers.

    Returns:
        A ``uint8`` array of shape ``(height, width, 3)``.
    """
    white = (255, 255, 255)
    canvas = np.full((height, width, 3), background, dtype=np.uint8)

    for shape in shapes:
        vertices = []
        for pair in shape['points'].split(';'):
            coords = tuple(map(float, pair.split(',')))
            # Only x and y are used; both are truncated toward zero.
            vertices.append((int(coords[0]), int(coords[1])))

        contour = (np.array(vertices) * scale_factor).astype(int)

        # Outline first, then the filled interior, both in white.
        canvas = cv2.drawContours(canvas, [contour], -1, color=white, thickness=0)
        canvas = cv2.fillPoly(canvas, [contour], color=white)

    return canvas
def _save_mask_image(cvat_xml: str, image: pathlib.Path, output: pathlib.Path):
    """Write ``image`` as a PNG whose alpha channel is its annotation mask.

    Nothing is written when ``image`` is not a ``.png``/``.jpg`` file, when
    it cannot be decoded, or when the resulting mask is completely empty.

    Args:
        cvat_xml: Path of the CVAT annotation XML file.
        image: Path of the source image; its file name is used to look up
            the matching ``<image>`` entries in the XML.
        output: Existing directory that receives ``<image stem>.png``.
    """
    import warnings

    # Skip non target file type. Checked before anything else so skipped
    # files never pay for parsing the annotation XML.
    if image.suffix.lower() not in ('.png', '.jpg'):
        return

    # -1 is cv2.IMREAD_UNCHANGED: keep the file's own channel layout.
    base_image = cv2.imread(str(image), -1)
    if base_image is None:
        # cv2.imread reports failure by returning None instead of raising.
        warnings.warn('could not read image, skipped: {}'.format(image))
        return
    # Grayscale images are 2-D, so only the first two axes can be relied on.
    height, width = base_image.shape[:2]

    root = etree.parse(cvat_xml).getroot()
    annotations = _parse_anno_file(root, image.name)

    background = np.zeros((height, width, 3), np.uint8)
    for annotation in annotations:
        background = _create_mask_file(width, height,
                                       background, annotation['shapes'])

    # Skip no mask image
    if not background.any():
        return

    # Bring the image to four channels whatever layout it was stored in.
    if base_image.ndim == 2:
        png_image = cv2.cvtColor(base_image, cv2.COLOR_GRAY2BGRA)
    elif base_image.shape[2] == 4:
        png_image = base_image.copy()
    else:
        png_image = cv2.cvtColor(base_image, cv2.COLOR_RGB2RGBA)

    # The mask is white on black, so any one of its channels is the alpha.
    png_image[:, :, 3] = background[:, :, 0]

    ypath = output / (image.stem + '.png')
    cv2.imwrite(str(ypath), png_image)
def create_dataset(image_dir: str, cvat_xml: str, output_dir: str,
                   n_jobs: int = 4):
    """Convert every image in ``image_dir`` into an alpha-masked PNG.

    Each directory entry is handed to ``_save_mask_image``, which decides
    whether the entry is processed or skipped.

    Args:
        image_dir: Directory containing the source images.
        cvat_xml: Path of the CVAT annotation XML file.
        output_dir: Directory for the generated PNG files; created
            (including parents) when it does not exist yet.
        n_jobs: Number of parallel joblib workers. Defaults to 4.
    """
    output = pathlib.Path(output_dir)
    output.mkdir(exist_ok=True, parents=True)

    images = pathlib.Path(image_dir).iterdir()
    Parallel(n_jobs=n_jobs)(
        delayed(_save_mask_image)(cvat_xml, image, output) for image in images
    )
if __name__ == '__main__':
    # Command-line entry point. All three paths are mandatory: marking them
    # required makes argparse print a usage error instead of letting a
    # missing value surface later as an unrelated TypeError.
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_dir', type=str, required=True,
                        help='images directory')
    parser.add_argument('--cvat_xml', type=str, required=True,
                        help='cvat annotation xml')
    parser.add_argument('--output_dir', type=str, required=True,
                        help='path for mask images')
    args = parser.parse_args()
    create_dataset(args.image_dir, args.cvat_xml, args.output_dir)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment