Skip to content

Instantly share code, notes, and snippets.

@Jacajack
Last active April 25, 2023 16:40
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Jacajack/1f40006289c54d9d644cc04e4f56c43b to your computer and use it in GitHub Desktop.
Save Jacajack/1f40006289c54d9d644cc04e4f56c43b to your computer and use it in GitHub Desktop.
Image mask generator for uni project
#!/usr/bin/python3
import sys
import os
import pathlib
import json
import cv2
import numpy as np
# Colour painted into the mask for each known COCO category ID.
# process_coco_img() reverses each tuple before handing it to OpenCV, so the
# values are presumably written here in (R, G, B) order - confirm against the
# consumers of the generated masks. Annotations whose category ID is not a key
# of this dict are reported and skipped.
CATEGORY_COLORS = {
1: (255, 0, 0), # NonMaskingBackground
2: (0, 255, 0), # MaskingBackground
3: (0, 0, 255), # Animal
4: (255, 255, 255), # NonMaskingForegroundAttention
}
# Output mask paths seen so far in this run, mapped to
# {'coco_path': <COCO file that first referenced the path>}. Used to warn when
# a second image entry maps to an already-written mask file and to report the
# number of unique paths at the end of the run.
written_paths = {}
def get_coco_paths(input_dir):
    """Return the candidate COCO annotation files located directly in *input_dir*.

    A candidate is a regular file with a ``.json`` suffix (any letter case)
    whose full path contains the text "coco", case-insensitively.
    Sub-directories are not searched.

    The result is sorted. ``os.listdir()`` returns entries in arbitrary
    order, and because a later COCO file overwrites masks written by an
    earlier one, an unsorted list would make the output depend on the
    filesystem.

    Args:
        input_dir: Directory to scan.

    Returns:
        Sorted list of matching paths, each prefixed with *input_dir*.

    Raises:
        OSError: If *input_dir* cannot be listed.
    """
    coco_paths = []
    for name in sorted(os.listdir(input_dir)):
        full_path = os.path.join(input_dir, name)
        if not os.path.isfile(full_path):
            continue
        if pathlib.Path(full_path).suffix.lower() != ".json":
            continue
        # NOTE: the match is against the whole path, not just the file name,
        # so an input directory whose own name contains "coco" makes every
        # JSON file in it a candidate. Kept as-is for compatibility.
        if "coco" in full_path.lower():
            coco_paths.append(full_path)
    print(f"Found {len(coco_paths)} COCO files...")
    return coco_paths
def load_cocos(coco_paths):
    """Parse every file in *coco_paths* as JSON and return the decoded documents.

    Each returned dict gets an extra ``"path"`` key holding the file it was
    loaded from, which later stages use in their messages.

    A file that cannot be opened, is not valid JSON, or whose top-level value
    is not a JSON object is reported on stdout and skipped, so one bad file
    does not abort the whole run.

    Args:
        coco_paths: Iterable of paths to JSON files.

    Returns:
        List of decoded dicts, in the order of *coco_paths*, minus the
        skipped files.
    """
    cocos = []
    for coco_path in coco_paths:
        try:
            # Decode explicitly as UTF-8 instead of the locale default;
            # "utf-8-sig" additionally tolerates a leading byte-order mark.
            with open(coco_path, "r", encoding="utf-8-sig") as f:
                coco = json.load(f)
        except (OSError, ValueError) as err:
            # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
            print(f"Failed to read '{coco_path}' ({err}) - skipping...")
            continue
        if not isinstance(coco, dict):
            # Only a JSON object can carry the "path" key added below.
            print(f"'{coco_path}' is not a JSON object - skipping...")
            continue
        coco["path"] = coco_path
        cocos.append(coco)
    print(f"Loaded {len(cocos)} COCO files...")
    return cocos
def process_cocos(cocos, input_dir, output_dir):
    """Generate a mask for every image referenced by every COCO document.

    For each image entry the image file is read from *input_dir* and passed
    to ``process_coco_img()``. Images that cannot be read or decoded are
    reported and skipped. A ``KeyError`` raised anywhere while handling a
    document (for example a missing ``images`` key) abandons the rest of that
    document. A summary of loaded images and valid documents is printed at
    the end.

    ``KeyboardInterrupt`` is not intercepted, so Ctrl-C stops the run.

    Args:
        cocos: Decoded COCO dicts, each carrying a ``"path"`` key.
        input_dir: Directory the image ``file_name`` entries are relative to.
        output_dir: Directory forwarded to ``process_coco_img()``.
    """
    broken_cocos = 0   # documents referencing at least one unreadable image
    failed_cocos = 0   # documents abandoned because of a KeyError
    failed_images = 0
    total_images = 0

    for coco in cocos:
        print(f"Loading {coco['path']}...")
        valid_coco = True
        try:
            for image_info in coco['images']:
                total_images += 1
                image_path = os.path.join(input_dir, image_info['file_name'])
                try:
                    # fromfile + imdecode rather than imread, as in the
                    # original code: the file is opened by NumPy, not OpenCV.
                    image = cv2.imdecode(np.fromfile(image_path, dtype=np.uint8), -1)
                except Exception:
                    image = None
                # imdecode() signals undecodable data by returning None
                # instead of raising, so both failure modes end up here.
                if image is None:
                    print(f"Failed to load '{image_path}' - skipping...")
                    failed_images += 1
                    valid_coco = False
                    continue
                print(f"Processing {image_path}...")
                process_coco_img(coco, image_info, image, input_dir, output_dir)
        except KeyError:
            print(f"Somebody doesn't know what COCO is - {coco['path']}")
            failed_cocos += 1
        else:
            # Counted only for documents that were walked to the end, so a
            # document is never both "failed" and "broken".
            if not valid_coco:
                broken_cocos += 1

    total_cocos = len(cocos)
    valid_cocos = total_cocos - failed_cocos - broken_cocos
    print(f"Loaded {total_images - failed_images}/{total_images} images.")
    print(f"{valid_cocos}/{total_cocos} COCOs valid ({failed_cocos} are completely broken, {broken_cocos} refernce bad images)")
def to_points(l):
    """Split the flat sequence *l* into consecutive slices of two items.

    ``[x0, y0, x1, y1]`` becomes ``[[x0, y0], [x1, y1]]``. Each chunk is a
    slice of *l*, so it has the same sequence type as the input. When *l*
    has an odd number of items the last chunk holds a single item.
    """
    pair_starts = range(0, len(l), 2)
    return [l[start:start + 2] for start in pair_starts]
def _fill_segment(mask_data, segment, color, annotation_id):
    """Draw one flat ``[x0, y0, x1, y1, ...]`` polygon onto *mask_data*.

    Returns True when the polygon was filled. Returns False, after reporting
    *annotation_id*, when the coordinates could not be converted to an
    integer point array or OpenCV rejected them.
    """
    try:
        # The conversion sits inside the try as well: an odd number of
        # coordinates or non-numeric values make np.array() raise.
        point_data = np.array(to_points(segment), dtype=np.int32)
        cv2.fillPoly(mask_data, [point_data], color)
    except Exception:
        print(f"Failed filling polygon - annotation ID {annotation_id}")
        return False
    return True


def process_coco_img(coco, image_info, image, input_dir, output_dir):
    """Render the annotations of one image into a colour mask and save it as PNG.

    The mask has the same height and width as *image*, starts out black, and
    every polygon of every annotation belonging to the image is filled with
    the colour ``CATEGORY_COLORS`` assigns to its category. The file is
    written to *output_dir* under the image's ``file_name`` with the
    extension replaced by ``.png``, and the path is recorded in the
    module-level ``written_paths``.

    Nothing is written when the image size disagrees with the COCO entry.
    Annotations with an unknown category, or with a segmentation that is not
    a non-empty list, are reported and skipped. Polygons that cannot be drawn
    are reported and do not count as valid. ``KeyboardInterrupt`` is not
    intercepted.

    Args:
        coco: Decoded COCO dict with ``"annotations"`` and ``"path"`` keys.
        image_info: The image's entry from ``coco["images"]``.
        image: Decoded image array; only its first two dimensions are used.
        input_dir: Unused; kept for interface compatibility.
        output_dir: Directory the mask is written to.

    Raises:
        KeyError: If a required COCO key is missing (handled by the caller).
    """
    output_path = os.path.splitext(os.path.join(output_dir, image_info['file_name']))[0] + ".png"

    # Validate the size before allocating the mask.
    if image_info['width'] != image.shape[1]:
        print(f"Image width does not match COCO (COCO {image_info['width']} vs {image.shape[1]}) - skipping")
        return
    if image_info['height'] != image.shape[0]:
        print(f"Image height does not match COCO (COCO {image_info['height']} vs {image.shape[0]}) - skipping")
        return

    mask_data = np.zeros((image.shape[0], image.shape[1], 3), dtype=np.uint8)
    warned_about_mt = False
    valid_segments = 0  # polygons actually drawn into the mask

    for anno in coco['annotations']:
        image_id = int(anno['image_id'])
        category_id = int(anno['category_id'])
        if image_id != image_info['id']:
            continue
        if category_id not in CATEGORY_COLORS:
            print(f"Category ID {anno['category_id']} in {coco['path']} is not known - skipping annotation!")
            continue

        # CATEGORY_COLORS tuples are reversed into OpenCV's channel order.
        color = tuple(reversed(CATEGORY_COLORS[category_id]))

        segmentation = anno['segmentation']
        if not isinstance(segmentation, (list, tuple)) or not segmentation:
            # Empty lists and non-polygon encodings (e.g. a dict) cannot be
            # drawn; skip just this annotation instead of failing the file.
            print(f"Annotation ID {anno.get('id')} in {coco['path']} has no polygon segmentation - skipping annotation!")
            continue

        if hasattr(segmentation[0], "__len__"):
            # Regular layout: a list of polygons, each a flat coordinate list.
            segments = segmentation
        else:
            # Non-standard layout: the segmentation itself is a single flat
            # coordinate list. Accept it, warning once per image.
            if not warned_about_mt:
                print(f"Warning: {coco['path']} is invalid, but will attempt to read (it's in M.T format)")
                warned_about_mt = True
            segments = [segmentation]

        for segment in segments:
            if _fill_segment(mask_data, segment, color, anno.get('id')):
                valid_segments += 1

    if valid_segments == 0:
        print(f"{image_info['file_name']} doesn't have any valid annotations!")

    if output_path in written_paths:
        previous = written_paths[output_path]
        print(f"Warning: {output_path} has already been written this time! (ref. by {previous['coco_path']}, now reading {coco['path']})")
    else:
        written_paths[output_path] = {'coco_path': coco['path']}

    # The mask is written even after the duplicate warning above, so the most
    # recently processed entry wins.
    try:
        success = cv2.imwrite(output_path, mask_data)
    except Exception:
        print(f"OpenCV error when writing {output_path}")
    else:
        if not success:
            print(f"Warning: OpenCV imwrite() probably errored on {output_path}")
# Script entry: <INPUT_DIR> holds the COCO JSON files and the images they
# reference; the generated PNG masks are written to <OUTPUT_DIR>.
if len(sys.argv) < 3:
    print(f"Usage: {sys.argv[0]} <INPUT_DIR> <OUTPUT_DIR>")
    # sys.exit() rather than exit(): the latter is a helper injected by the
    # site module and is not guaranteed to exist.
    sys.exit(1)

input_dir = sys.argv[1]
output_dir = sys.argv[2]
print(f"Input dir: {input_dir}")
print(f"Output dir: {output_dir}")

if not os.path.isdir(output_dir):
    print(f"Creating output directory!")
    try:
        # makedirs also creates missing parent directories; it raises
        # OSError when the path exists but is not a directory.
        os.makedirs(output_dir, exist_ok=True)
    except OSError:
        print(f"Failed to create output dir {output_dir}")
        sys.exit(1)

process_cocos(load_cocos(get_coco_paths(input_dir)), input_dir, output_dir)
print(f"Written {len(written_paths)} unique image paths.")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment