Skip to content

Instantly share code, notes, and snippets.

@psycharo-zz
Created March 1, 2017 17:02
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save psycharo-zz/59d9625c89d2f7881a7c6f0152b4182e to your computer and use it in GitHub Desktop.
Save psycharo-zz/59d9625c89d2f7881a7c6f0152b4182e to your computer and use it in GitHub Desktop.
Custom TensorFlow queue runner for reading the Cityscapes dataset.
def instance_to_regression_map(instances, cids):
    """Convert an instance label map to a per-pixel bounding-box regression map.

    Args:
      instances: 2-D integer instance label map. A pixel is treated as
        belonging to class ``cid`` when its value lies in
        ``[cid * 1000, (cid + 1) * 1000)``; each distinct value in that range
        is one instance.
      cids: iterable of class ids whose instances should be converted.

    Returns:
      uint16 array of shape ``instances.shape[:2] + (4,)``. For every pixel of
      a selected instance, the four channels hold the pixel distance to that
      instance's bounding-box top, left, bottom and right edges. All other
      pixels are zero.
    """
    image_size = instances.shape[:2]
    reg = np.zeros(image_size + (4,), dtype=np.uint16)

    # Pixels whose instance id encodes one of the requested classes.
    # Use the builtin `bool`; the `np.bool` alias was removed in NumPy 1.24.
    mask = np.zeros(image_size, dtype=bool)
    for cid in cids:
        mask |= (instances >= cid * 1000) & (instances < (cid + 1) * 1000)

    # TODO: np.where below rescans the whole image once per instance; a
    # single-pass bounding-box computation would scale better.
    for iid in np.unique(instances[mask]):
        ys, xs = np.where(instances == iid)
        reg[ys, xs, 0] = ys - ys.min()  # distance to top edge
        reg[ys, xs, 1] = xs - xs.min()  # distance to left edge
        reg[ys, xs, 2] = ys.max() - ys  # distance to bottom edge
        reg[ys, xs, 3] = xs.max() - xs  # distance to right edge
    return reg
def read_example(rgb_fname, json_fname, void_train_id=19):
    """Load one Cityscapes example: image, semantic labels and regression map.

    Args:
      rgb_fname: path of the image file, read with OpenCV.
      json_fname: path of the polygon annotation JSON for that image.
      void_train_id: value written into the segmentation map wherever the
        rendered trainId label is 255.

    Returns:
      (rgb, seg, reg):
        rgb: image from ``cv2.imread`` with the last axis reversed
          (OpenCV's BGR channel order -> RGB).
        seg: (imgHeight, imgWidth) uint8 trainId map, 255 replaced by
          ``void_train_id``.
        reg: output of ``instance_to_regression_map`` for trainIds 11..18.

    Raises:
      IOError: if OpenCV cannot read ``rgb_fname``.
    """
    # trainIds whose instances contribute to the regression map.
    INSTANCE_CIDS = np.arange(11, 19)

    bgr = cv2.imread(rgb_fname)
    if bgr is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise IOError('could not read image: %s' % rgb_fname)
    rgb = bgr[:, :, ::-1]

    anno = json2instanceImg.Annotation()
    anno.fromJsonFile(json_fname)

    # NOTE(review): the dtypes below assume the rendered instance image is a
    # 32-bit integer PIL image and the label image an 8-bit one; confirm
    # against the cityscapesscripts version in use.
    instances_pil = json2instanceImg.createInstanceImage(anno, 'trainIds')
    # np.frombuffer replaces the deprecated binary mode of np.fromstring.
    instances = np.frombuffer(instances_pil.tobytes(), dtype=np.int32)
    instances = instances.reshape((anno.imgHeight, anno.imgWidth))

    seg_pil = json2labelImg.createLabelImage(anno, 'trainIds')
    # frombuffer over `bytes` is read-only; copy because seg is edited below.
    seg = np.frombuffer(seg_pil.tobytes(), dtype=np.uint8).copy()
    seg = seg.reshape((anno.imgHeight, anno.imgWidth))
    seg[seg == 255] = void_train_id

    reg = instance_to_regression_map(instances, INSTANCE_CIDS)
    return rgb, seg, reg
class CityscapesRunner(object):
    """Feeds Cityscapes examples into a TensorFlow FIFOQueue from Python threads.

    Threads built by `create_threads` share one cursor into `filenames`. Each
    thread loads examples with `read_example` and pushes them into the queue
    through placeholders. `inputs()` returns the dequeue op for consumers.

    NOTE(review): the cursor is never reset and the queue is never closed.
    Each file is therefore enqueued once, and after every example has been
    dequeued, further dequeues block instead of signalling end-of-data. This
    class uses the TF1 graph API (tf.placeholder / tf.FIFOQueue).
    """

    def __init__(self, filenames, src_size, num_threads, capacity=128):
        """Build the placeholders, the queue and its enqueue op.

        Args:
          filenames: sequence of (rgb_fname, json_fname) pairs.
          src_size: (height, width) shared by every example; it is baked into
            the placeholder and queue shapes.
          num_threads: number of loader threads `create_threads` builds.
          capacity: maximum number of examples buffered in the queue.
        """
        self.filenames = filenames
        self.num_threads = num_threads
        # Guards `step`, the index of the next unread entry in `filenames`.
        self.lock = threading.Lock()
        self.step = 0
        self.rgb = tf.placeholder(tf.uint8, [src_size[0], src_size[1], 3])
        self.seg = tf.placeholder(tf.uint8, [src_size[0], src_size[1]])
        self.reg = tf.placeholder(tf.uint16, [src_size[0], src_size[1], 4])
        self.queue = tf.FIFOQueue(capacity=capacity,
                                  dtypes=[tf.uint8, tf.uint8, tf.uint16],
                                  shapes=[[src_size[0], src_size[1], 3],
                                          [src_size[0], src_size[1]],
                                          [src_size[0], src_size[1], 4]])
        self.enqueue_op = self.queue.enqueue([self.rgb, self.seg, self.reg])

    def _data_iterator(self):
        """Yield loaded examples until every entry of `filenames` is claimed.

        The next index is claimed under the lock, so concurrent threads never
        load the same entry twice. The file is read outside the lock so that
        threads can load in parallel.
        """
        while True:
            with self.lock:
                if self.step == len(self.filenames):
                    break
                idx = self.step
                self.step += 1
            rgb_fname, json_fname = self.filenames[idx]
            yield read_example(rgb_fname, json_fname)

    def _run(self, sess, coord):
        """Thread body: enqueue examples until data runs out or `coord` stops.

        Errors are forwarded to `coord.request_stop` when a coordinator is
        given. Otherwise they are re-raised so the failure is reported instead
        of the thread exiting silently.
        """
        try:
            for rgb, seg, reg in self._data_iterator():
                if coord and coord.should_stop():
                    break
                feed_dict = {
                    self.rgb: rgb,
                    self.seg: seg,
                    self.reg: reg,
                }
                sess.run(self.enqueue_op, feed_dict)
        except Exception as e:
            if coord:
                coord.request_stop(e)
            else:
                # No coordinator would ever see this error, so re-raise it
                # rather than silently dropping it.
                raise

    def inputs(self):
        """Return the dequeue op yielding (rgb, seg, reg) tensors."""
        return self.queue.dequeue()

    def create_threads(self, sess, coord=None, daemon=False, start=False):
        """Create `num_threads` loader threads running `_run(sess, coord)`.

        Args:
          sess: session used to run the enqueue op.
          coord: optional coordinator. Started threads are registered with it,
            and it receives any loader error.
          daemon: daemon flag applied to every thread.
          start: if True, start each thread (and register it with `coord`).

        Returns:
          List of the created threading.Thread objects.
        """
        threads = [threading.Thread(target=self._run, args=(sess, coord))
                   for _ in range(self.num_threads)]
        for t in threads:
            t.daemon = daemon
            if start:
                t.start()
                if coord:
                    coord.register_thread(t)
        return threads
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment