Created April 26, 2018 05:14
-
-
Save nilakshdas/752e61b9048ea552a76e441f96b5d178 to your computer and use it in GitHub Desktop.
`clipart_segmentation.py` generates segmentation maps for the Abstract Scenes dataset from https://vision.ece.vt.edu/clipart/
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import namedtuple | |
import json | |
import os | |
from addict import Dict | |
from matplotlib import pyplot as plt | |
import numpy as np | |
from scipy.misc import imread, imsave, imresize | |
from tqdm import tqdm | |
# Global configuration. Edit the three paths below before running.
CONFIG = Dict()
CONFIG.SCENE_WIDTH = 500  # width of every output label map, in pixels
CONFIG.SCENE_HEIGHT = 400  # height of every output label map, in pixels
CONFIG.SPRITES_DIR = '/path/to/AbstractScenes_v1.1/Pngs/'
CONFIG.SCENES_PATH = '/path/to/AbstractScenes_v1.1/Scenes_10020.txt'
CONFIG.OUT_DIR = '/path/to/output/folder/'

# Map every file name in SPRITES_DIR to an integer class label.
# Labels start at 1 because 0 is the background value of the label map.
# With labels starting at 0, the first sprite file (in sorted order) would
# be written as 0 and then skipped as transparent, so it would never show up
# in the output maps.
# NOTE: os.listdir runs at import time, so SPRITES_DIR must exist.
CONFIG.SPRITE_ID_MAPPING = {
    str(sprite_id): label
    for label, sprite_id in enumerate(
        sorted(os.listdir(CONFIG.SPRITES_DIR)), start=1)
}

# Labels are stored in uint8 maps and looked up in a 256-entry colormap,
# so the largest label must be <= 255.
if len(CONFIG.SPRITE_ID_MAPPING) > 255:
    raise ValueError(
        'Too many sprite files (%d) in %s: at most 255 labels fit in uint8'
        % (len(CONFIG.SPRITE_ID_MAPPING), CONFIG.SPRITES_DIR))
# Colormap used to render label maps as RGB images. The bits of label k are
# dealt out three at a time to the red, green and blue channels, filling each
# channel from its most significant bit downwards. This matches the colormap
# used by the PASCAL VOC devkit.
N = 256


def bitget(byteval, idx):
    """Return True when bit ``idx`` (0 = least significant) of ``byteval`` is set."""
    return (byteval >> idx) & 1 == 1


def _label_color(label):
    """Return the (red, green, blue) triple assigned to integer ``label``."""
    red = green = blue = 0
    bits = label
    for shift in range(7, -1, -1):
        red |= bitget(bits, 0) << shift
        green |= bitget(bits, 1) << shift
        blue |= bitget(bits, 2) << shift
        bits >>= 3
    return red, green, blue


cmap = np.array([_label_color(label) for label in range(N)], dtype=np.uint8)
# One clipart instance placed in a scene. It is defined at module level
# (not inside load_scenes) so every call returns the same type and the
# instances can be pickled, e.g. when sent to worker processes.
Sprite = namedtuple('Sprite', ['id', 'x', 'y', 'z', 'flipped'])


def load_scenes(path=None):
    """Parse the scene description file into a list of scenes.

    Expected file layout:
      * The first line holds the number of scenes.
      * Each scene starts with a tab-separated header line:
        scene id, then number of sprites.
      * That many tab-separated sprite lines follow the header.

    Each sprite line is read as follows:
      * column 0 is the sprite id, a file name inside CONFIG.SPRITES_DIR;
      * columns 3, 4 and 5 are the integer x, y and z values;
      * column 6 is the flip flag, where 1 means flipped.
    The other columns are ignored.

    Args:
        path: Scene file to read. Defaults to CONFIG.SCENES_PATH.

    Returns:
        A list with one entry per scene, in file order. Each entry is a list
        of Sprite tuples, also in file order. The scene ids from the header
        lines are read but not returned.
    """
    if path is None:
        path = CONFIG.SCENES_PATH
    scenes = []
    with open(path, 'r') as f:
        num_scenes = int(f.readline().strip())
        for _ in range(num_scenes):
            _scene_id, num_sprites = map(int, f.readline().strip().split('\t'))
            scene = []
            for _ in range(num_sprites):
                fields = f.readline().strip().split('\t')
                scene.append(Sprite(id=fields[0],
                                    x=int(fields[3]),
                                    y=int(fields[4]),
                                    z=int(fields[5]),
                                    flipped=(int(fields[6]) == 1)))
            scenes.append(scene)
    return scenes
def _paste_sprite(segmentation_map, sprite_image, x, y):
    """Copy the nonzero pixels of a 2-D sprite into segmentation_map, in place.

    The sprite is centered at column x, row y of the map. When nothing is
    clipped, its top-left corner lands at (x - w // 2, y - h // 2). Parts
    that fall outside the map are clipped. Zero pixels of the sprite are
    treated as transparent and leave the map unchanged.
    """
    map_h, map_w = segmentation_map.shape
    h, w = sprite_image.shape
    half_h, half_w = h // 2, w // 2
    # Number of sprite columns/rows cut off by the left/top edge of the map.
    l_index = max(0, half_w - x)
    t_index = max(0, half_h - y)
    # Destination rectangle [ymin:ymax, xmin:xmax] inside the map.
    xmin, ymin = max(0, x - half_w), max(0, y - half_h)
    xmax = min(xmin + w - l_index, map_w)
    ymax = min(ymin + h - t_index, map_h)
    if xmax <= xmin or ymax <= ymin:
        return  # the sprite lies entirely outside the map
    patch = sprite_image[t_index:t_index + (ymax - ymin),
                         l_index:l_index + (xmax - xmin)]
    # Basic slicing gives a view, so the masked assignment writes into the map.
    region = segmentation_map[ymin:ymax, xmin:xmax]
    opaque = patch != 0
    region[opaque] = patch[opaque]


def create_segmentation_map(scene):
    """Render one scene as a 2-D label map.

    Args:
        scene: List of Sprite tuples, as returned by load_scenes. Sprites
            are drawn in list order, so where two overlap the later one wins.

    Returns:
        uint8 array of shape (CONFIG.SCENE_HEIGHT, CONFIG.SCENE_WIDTH).
        Pixels covered by a sprite hold that sprite's label from
        CONFIG.SPRITE_ID_MAPPING; all other pixels are 0.
    """
    segmentation_map = np.zeros(
        (CONFIG.SCENE_HEIGHT, CONFIG.SCENE_WIDTH), dtype=np.uint8)
    for sprite in scene:
        sprite_label = CONFIG.SPRITE_ID_MAPPING[sprite.id]
        # NOTE(review): scipy.misc.imread/imresize were deprecated in SciPy
        # 1.0 and removed in later releases. This script needs an older SciPy
        # with Pillow installed.
        sprite_image = imread(os.path.join(CONFIG.SPRITES_DIR, sprite.id))
        if sprite_image.ndim == 3:
            # Collapse the channel axis. A pixel counts as part of the sprite
            # when any channel is nonzero. This assumes fully transparent
            # pixels have all-zero channels. TODO confirm for these PNGs.
            sprite_image = sprite_image.sum(axis=2)
        # An int size makes imresize scale by that percentage, so z = 0, 1, 2
        # gives 100%, 70% or 49% of the original size.
        sprite_image = imresize(sprite_image, [100, 70, 49][sprite.z])
        sprite_image[sprite_image != 0] = sprite_label
        if sprite.flipped:
            sprite_image = np.fliplr(sprite_image)
        _paste_sprite(segmentation_map, sprite_image, sprite.x, sprite.y)
    return segmentation_map
def create_segmentation_image(segmentation_map, colormap=None):
    """Color a label map for visualization.

    Args:
        segmentation_map: Integer label map, for example the output of
            create_segmentation_map.
        colormap: Array-like of shape (num_labels, 3) mapping each label to
            an RGB color. Defaults to the module-level cmap (256 entries).

    Returns:
        Array of shape segmentation_map.shape + (3,). The pixel at (y, x) is
        colormap[segmentation_map[y, x]]. With the default colormap the
        result is uint8.
    """
    if colormap is None:
        colormap = cmap
    colormap = np.asarray(colormap)
    # One fancy-indexing lookup colors every pixel at once, instead of
    # scanning the whole map once per distinct label.
    return colormap[segmentation_map]
def main():
    """Write a label map, its color rendering and the label mapping to disk.

    For the i-th scene returned by load_scenes (0-based), this writes
    Scene<i>_0.npy (the label map) and Scene<i>_0.png (its color
    rendering) into CONFIG.OUT_DIR. It then writes
    CONFIG.SPRITE_ID_MAPPING there as labelmap.json.
    """
    # np.save and imsave do not create missing directories.
    if not os.path.isdir(CONFIG.OUT_DIR):
        os.makedirs(CONFIG.OUT_DIR)
    scenes = load_scenes()
    for i, scene in tqdm(enumerate(scenes), total=len(scenes)):
        segmentation_map = create_segmentation_map(scene)
        segmentation_image = create_segmentation_image(segmentation_map)
        # NOTE(review): output files are named by enumeration index with a
        # fixed '_0' suffix; the scene id parsed from the scene file is not
        # used. Confirm this matches the naming expected downstream.
        filename = 'Scene%d_0' % i
        np.save(os.path.join(CONFIG.OUT_DIR, filename + '.npy'),
                segmentation_map, allow_pickle=False)
        imsave(os.path.join(CONFIG.OUT_DIR, filename + '.png'),
               segmentation_image)
    with open(os.path.join(CONFIG.OUT_DIR, 'labelmap.json'), 'w') as f:
        json.dump(CONFIG.SPRITE_ID_MAPPING, f)


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment