Skip to content

Instantly share code, notes, and snippets.

@nilakshdas
Created April 26, 2018 05:14
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save nilakshdas/752e61b9048ea552a76e441f96b5d178 to your computer and use it in GitHub Desktop.
Save nilakshdas/752e61b9048ea552a76e441f96b5d178 to your computer and use it in GitHub Desktop.
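`clipart_segmentation.py` generates per-pixel segmentation maps for the Abstract Scenes dataset (https://vision.ece.vt.edu/clipart/). For each scene it writes a `.npy` label array and a color-coded `.png`, plus a `labelmap.json` mapping sprite filenames to labels.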
from collections import namedtuple
import json
import os
from addict import Dict
from matplotlib import pyplot as plt
import numpy as np
from scipy.misc import imread, imsave, imresize
from tqdm import tqdm
CONFIG = Dict()
CONFIG.SCENE_WIDTH = 500
CONFIG.SCENE_HEIGHT = 400
CONFIG.SPRITES_DIR = '/path/to/AbstractScenes_v1.1/Pngs/'
CONFIG.SCENES_PATH = '/path/to/AbstractScenes_v1.1/Scenes_10020.txt'
CONFIG.OUT_DIR = '/path/to/output/folder/'
# Map each file in SPRITES_DIR (sorted by name) to an integer class label.
# Labels start at 1 because 0 marks background in the segmentation maps:
# with a 0-based enumeration the first sprite in sorted order was written
# as 0 and therefore vanished from every map.
CONFIG.SPRITE_ID_MAPPING = {
    sprite_id: label
    for label, sprite_id in enumerate(
        sorted(os.listdir(CONFIG.SPRITES_DIR)), start=1)
}
# Label maps are uint8 and are colored through a 256-entry palette, so any
# label above 255 would silently wrap around.
if len(CONFIG.SPRITE_ID_MAPPING) > 255:
    raise ValueError(
        'Found %d sprites in %s, but uint8 label maps can hold at most 255'
        % (len(CONFIG.SPRITE_ID_MAPPING), CONFIG.SPRITES_DIR))
# Palette size: one RGB color per possible uint8 label value.
N = 256
# RGB color per label, filled in below; row k is the color for label k.
cmap = np.zeros(shape=(N, 3), dtype=np.uint8)
def bitget(byteval, idx):
    """Return True if bit ``idx`` (0 = least significant) of ``byteval`` is set."""
    shifted = byteval >> idx
    return bool(shifted & 1)
# Populate the palette (same scheme as the PASCAL VOC devkit label colormap).
# The label's bits are dealt round-robin to the channels: bit 0 -> R,
# bit 1 -> G, bit 2 -> B, bit 3 -> R, and so on. Each channel is filled from
# its most significant bit downwards.
for i in range(N):
    r, g, b = 0, 0, 0
    c = i
    for j in range(8):
        r |= bitget(c, 0) << (7 - j)
        g |= bitget(c, 1) << (7 - j)
        b |= bitget(c, 2) << (7 - j)
        c >>= 3
    cmap[i] = (r, g, b)
def load_scenes(path=None):
    """Parse an Abstract Scenes scene-description file.

    The file starts with a line holding the number of scenes. Each scene is
    a tab-separated header line (scene id, number of sprites) followed by
    one tab-separated line per sprite. Of a sprite line's columns this reads
    the PNG filename (column 0), x (3), y (4), z (5) and the flip flag (6,
    where 1 means flipped).

    Args:
        path: scene file to read; defaults to ``CONFIG.SCENES_PATH``.

    Returns:
        A list with one entry per scene, each a list of ``Sprite`` named
        tuples ``(id, x, y, z, flipped)`` in file order.
    """
    Sprite = namedtuple('Sprite',
                        ['id', 'x', 'y', 'z', 'flipped'])
    if path is None:
        path = CONFIG.SCENES_PATH

    def read_fields(f):
        # One tab-separated record per line.
        return f.readline().strip().split('\t')

    scenes = []
    with open(path, 'r') as f:
        num_scenes = int(f.readline().strip())
        for _ in range(num_scenes):
            # The scene id column is parsed but not returned.
            _scene_id, num_sprites = map(int, read_fields(f))
            scene = []
            for _ in range(num_sprites):
                cols = read_fields(f)
                scene.append(Sprite(id=cols[0],
                                    x=int(cols[3]),
                                    y=int(cols[4]),
                                    z=int(cols[5]),
                                    flipped=(int(cols[6]) == 1)))
            scenes.append(scene)
    return scenes
def create_segmentation_map(scene):
    """Rasterize one scene into a 2-D uint8 label map.

    Processing per sprite:

    1. Read its PNG from ``CONFIG.SPRITES_DIR``.
    2. Collapse it to one channel by summing axis 2.
    3. Resize it by its depth ``z`` (100%, 70% or 49% for z = 0, 1, 2).
    4. Mirror it horizontally if ``flipped``.
    5. Paste it centered on ``(sprite.x, sprite.y)``.

    Every non-zero pixel of the processed sprite is written as that sprite's
    label from ``CONFIG.SPRITE_ID_MAPPING``. Sprites later in ``scene``
    overwrite earlier ones where they overlap, and anything falling outside
    the scene is clipped.

    Args:
        scene: iterable of sprites with ``id``, ``x``, ``y``, ``z`` and
            ``flipped`` attributes (as produced by ``load_scenes``).

    Returns:
        ``np.ndarray`` of shape ``(CONFIG.SCENE_HEIGHT, CONFIG.SCENE_WIDTH)``
        and dtype uint8, 0 wherever no sprite pixel was drawn.
    """
    height, width = CONFIG.SCENE_HEIGHT, CONFIG.SCENE_WIDTH
    segmentation_map = np.zeros((height, width), dtype=np.uint8)
    # NOTE(review): sprites are layered in file order; confirm this matches
    # the dataset renderer's draw order.
    for sprite in scene:
        sprite_label = CONFIG.SPRITE_ID_MAPPING[sprite.id]
        sprite_image = imread(os.path.join(CONFIG.SPRITES_DIR, sprite.id))
        sprite_image = sprite_image.sum(axis=2)
        # scipy.misc.imresize treats an int size as a percentage and returns
        # a byte-scaled uint8 image.
        # NOTE(review): imresize/imread were removed in SciPy >= 1.2/1.3.
        sprite_image = imresize(sprite_image, [100, 70, 49][sprite.z])
        sprite_image[sprite_image != 0] = sprite_label
        if sprite.flipped:
            sprite_image = np.fliplr(sprite_image)

        h, w = sprite_image.shape
        h_, w_ = h // 2, w // 2
        # Offset into the sprite where it sticks out past the left/top edge.
        l_index = max(0, w_ - sprite.x)
        t_index = max(0, h_ - sprite.y)
        # Destination rectangle in scene coordinates, clipped to the canvas.
        xmin, ymin = max(0, sprite.x - w_), max(0, sprite.y - h_)
        xmax = min(xmin + w - l_index, width)
        ymax = min(ymin + h - t_index, height)
        if xmax <= xmin or ymax <= ymin:
            continue  # sprite lies entirely outside the scene

        patch = sprite_image[t_index:t_index + (ymax - ymin),
                             l_index:l_index + (xmax - xmin)]
        mask = patch != 0
        # Basic slicing yields a view, so the masked write lands in the map.
        segmentation_map[ymin:ymax, xmin:xmax][mask] = patch[mask]
    return segmentation_map
def create_segmentation_image(segmentation_map, palette=None):
    """Color-code a 2-D label map for visualization.

    Args:
        segmentation_map: 2-D integer array of labels.
        palette: optional ``(K, 3)`` array of RGB colors indexed by label;
            defaults to the module-level ``cmap``.

    Returns:
        ``(h, w, 3)`` uint8 RGB image holding ``palette[label]`` at each pixel.

    Raises:
        ValueError: if ``segmentation_map`` is not 2-D.
        IndexError: if a label has no entry in ``palette``.
    """
    if palette is None:
        palette = cmap
    if segmentation_map.ndim != 2:
        raise ValueError('segmentation_map must be 2-D, got shape %r'
                         % (segmentation_map.shape,))
    # Integer-array indexing looks up every pixel's color in one vectorized
    # pass, instead of building a full-image mask per distinct label.
    return np.asarray(palette, dtype=np.uint8)[segmentation_map]
def main():
    """Render every scene and write the results to ``CONFIG.OUT_DIR``.

    Scenes are numbered by their 0-based position ``i`` in the scenes file.
    For each scene this writes ``Scene<i>_0.npy`` (the raw uint8 label map)
    and ``Scene<i>_0.png`` (its color-coded rendering). Finally it writes
    ``labelmap.json``, mapping sprite filenames to labels.
    """
    # Previously np.save failed if the output folder did not exist yet.
    os.makedirs(CONFIG.OUT_DIR, exist_ok=True)
    scenes = load_scenes()
    for i, scene in tqdm(enumerate(scenes), total=len(scenes)):
        segmentation_map = create_segmentation_map(scene)
        segmentation_image = create_segmentation_image(segmentation_map)
        # NOTE(review): the name uses the enumeration index with a fixed "_0"
        # suffix, not the scene id from the file; confirm it matches the
        # dataset's rendered-image naming.
        filename = 'Scene%d_0' % i
        out_stem = os.path.join(CONFIG.OUT_DIR, filename)
        np.save(out_stem + '.npy', segmentation_map, allow_pickle=False)
        imsave(out_stem + '.png', segmentation_image)
    with open(os.path.join(CONFIG.OUT_DIR, 'labelmap.json'), 'w') as f:
        json.dump(CONFIG.SPRITE_ID_MAPPING, f)
if __name__ == '__main__':
    # Run the pipeline only when executed as a script, not on import.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment