Skip to content

Instantly share code, notes, and snippets.

@mattroz
Created January 26, 2021 15:51
Show Gist options
  • Save mattroz/a2abb7e98bf8d20a0e83f0c306fecd2a to your computer and use it in GitHub Desktop.
Save mattroz/a2abb7e98bf8d20a0e83f0c306fecd2a to your computer and use it in GitHub Desktop.
This script allows to get a JSON file with specific categories instanses from COCO dataset.
# This code has been taken from https://github.com/immersive-limit/coco-manager/
# Usage: python filter.py --input_json /path/to/instances_train2017.json
# --output_json path_to/saved.json --categories person dog cat
import json
from pathlib import Path
class CocoFilter():
""" Filters the COCO dataset
"""
def _process_info(self):
self.info = self.coco['info']
def _process_licenses(self):
self.licenses = self.coco['licenses']
def _process_categories(self):
self.categories = dict()
self.super_categories = dict()
self.category_set = set()
for category in self.coco['categories']:
cat_id = category['id']
super_category = category['supercategory']
# Add category to categories dict
if cat_id not in self.categories:
self.categories[cat_id] = category
self.category_set.add(category['name'])
else:
print(f'ERROR: Skipping duplicate category id: {category}')
# Add category id to the super_categories dict
if super_category not in self.super_categories:
self.super_categories[super_category] = {cat_id}
else:
self.super_categories[super_category] |= {cat_id} # e.g. {1, 2, 3} |= {4} => {1, 2, 3, 4}
def _process_images(self):
self.images = dict()
for image in self.coco['images']:
image_id = image['id']
if image_id not in self.images:
self.images[image_id] = image
else:
print(f'ERROR: Skipping duplicate image id: {image}')
def _process_segmentations(self):
self.segmentations = dict()
for segmentation in self.coco['annotations']:
image_id = segmentation['image_id']
if image_id not in self.segmentations:
self.segmentations[image_id] = []
self.segmentations[image_id].append(segmentation)
def _filter_categories(self):
""" Find category ids matching args
Create mapping from original category id to new category id
Create new collection of categories
"""
missing_categories = set(self.filter_categories) - self.category_set
if len(missing_categories) > 0:
print(f'Did not find categories: {missing_categories}')
should_continue = input('Continue? (y/n) ').lower()
if should_continue != 'y' and should_continue != 'yes':
print('Quitting early.')
quit()
self.new_category_map = dict()
new_id = 1
for key, item in self.categories.items():
if item['name'] in self.filter_categories:
self.new_category_map[key] = new_id
new_id += 1
self.new_categories = []
for original_cat_id, new_id in self.new_category_map.items():
new_category = dict(self.categories[original_cat_id])
new_category['id'] = new_id
self.new_categories.append(new_category)
def _filter_annotations(self):
""" Create new collection of annotations matching category ids
Keep track of image ids matching annotations
"""
self.new_segmentations = []
self.new_image_ids = set()
for image_id, segmentation_list in self.segmentations.items():
for segmentation in segmentation_list:
original_seg_cat = segmentation['category_id']
if original_seg_cat in self.new_category_map.keys():
new_segmentation = dict(segmentation)
new_segmentation['category_id'] = self.new_category_map[original_seg_cat]
self.new_segmentations.append(new_segmentation)
self.new_image_ids.add(image_id)
def _filter_images(self):
""" Create new collection of images
"""
self.new_images = []
for image_id in self.new_image_ids:
self.new_images.append(self.images[image_id])
def main(self, args):
# Open json
self.input_json_path = Path(args.input_json)
self.output_json_path = Path(args.output_json)
self.filter_categories = args.categories
# Verify input path exists
if not self.input_json_path.exists():
print('Input json path not found.')
print('Quitting early.')
quit()
# Verify output path does not already exist
if self.output_json_path.exists():
should_continue = input('Output path already exists. Overwrite? (y/n) ').lower()
if should_continue != 'y' and should_continue != 'yes':
print('Quitting early.')
quit()
# Load the json
print('Loading json file...')
with open(self.input_json_path) as json_file:
self.coco = json.load(json_file)
# Process the json
print('Processing input json...')
self._process_info()
self._process_licenses()
self._process_categories()
self._process_images()
self._process_segmentations()
# Filter to specific categories
print('Filtering...')
self._filter_categories()
self._filter_annotations()
self._filter_images()
# Build new JSON
new_master_json = {
'info': self.info,
'licenses': self.licenses,
'images': self.new_images,
'annotations': self.new_segmentations,
'categories': self.new_categories
}
# Write the JSON to a file
print('Saving new json file...')
with open(self.output_json_path, 'w+') as output_file:
json.dump(new_master_json, output_file)
print('Filtered json saved.')
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Filter COCO JSON: "
"Filters a COCO Instances JSON file to only include specified categories. "
"This includes images, and annotations. Does not modify 'info' or 'licenses'.")
parser.add_argument("-i", "--input_json", dest="input_json",
help="path to a json file in coco format")
parser.add_argument("-o", "--output_json", dest="output_json",
help="path to save the output json")
parser.add_argument("-c", "--categories", nargs='+', dest="categories",
help="List of category names separated by spaces, e.g. -c person dog bicycle")
args = parser.parse_args()
cf = CocoFilter()
cf.main(args)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment