Skip to content

Instantly share code, notes, and snippets.

@danoneata
Last active August 29, 2015 14:08
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save danoneata/ba15809c0cb9287a2f38 to your computer and use it in GitHub Desktop.
Save danoneata/ba15809c0cb9287a2f38 to your computer and use it in GitHub Desktop.
Generates the segmentation at any level of the merge hierarchy built on top of the base SLIC superpixels.
"""Usage:
python aggregate_tree.py 500000
python aggregate_tree.py -3000
"""
from matplotlib import pyplot as plt
import numpy as np
import pdb
from scipy.misc import imsave
import sys
class DisjointSet(object):
    """Union-find over the integers 0 .. n-1.

    Uses union by rank together with full path compression.
    See http://en.wikipedia.org/wiki/Disjoint-set_data_structure
    """

    def __init__(self, n):
        # Every element starts out as the root of its own singleton set.
        self.parents = np.arange(n)
        # Upper bound on each root's tree height; only meaningful for roots.
        self.ranks = np.zeros(n, dtype=np.int32)

    def find_representative(self, x):
        """Return the root of the set containing `x`.

        Every node visited on the way up is re-pointed directly at the root.
        """
        parents = self.parents
        # First pass: climb to the root.
        root = x
        while parents[root] != root:
            root = parents[root]
        # Second pass: compress the path so each visited node points at root.
        node = x
        while node != root:
            parent = parents[node]
            parents[node] = root
            node = parent
        return parents[x]

    def same_set(self, x, y):
        """Return True when `x` and `y` share a representative."""
        return self.find_representative(x) == self.find_representative(y)

    def union(self, x, y):
        """Merge the sets holding `x` and `y`.

        Returns 1 if a merge happened, 0 if they were already together.
        """
        big = self.find_representative(x)
        small = self.find_representative(y)
        if big == small:
            return 0
        # Make `big` the root with the larger rank; on a tie, x's root wins.
        if self.ranks[big] < self.ranks[small]:
            big, small = small, big
        self.parents[small] = big
        # Attaching equal-rank trees makes the combined tree one level taller.
        if self.ranks[big] == self.ranks[small]:
            self.ranks[big] += 1
        return 1
def aggregate_tree(tree, labels, level):
    """Map every node id to its cluster representative after `level` merges.

    Parameters
    ----------
    tree : array-like, indexed as ``tree[i, 0]`` / ``tree[i, 1]``
        Row ``i`` names the two node ids joined by the i-th merge. The
        number of node ids is taken to be ``len(tree) + 1``. This presumably
        means ``len(tree)`` binary merges over that many base segments
        (TODO confirm against the data producer).
    labels : unused
        Kept only so existing callers keep working.
    level : int
        How many leading merges to apply. A negative value counts from the
        end, as in Python indexing: ``-k`` applies all but the last ``k``
        merges.

    Returns
    -------
    numpy.ndarray of int32, length ``len(tree) + 1``
        Entry ``i`` is the representative node id of node ``i``'s cluster.

    Raises
    ------
    ValueError
        If ``level`` is outside ``[-len(tree), len(tree)]``.
    """
    n_merges = len(tree)
    if not -n_merges <= level <= n_merges:
        raise ValueError('level must be in [%d, %d], got %d'
                         % (-n_merges, n_merges, level))
    if level < 0:
        level += n_merges

    nr_nodes = n_merges + 1
    disjoint_set = DisjointSet(nr_nodes)
    # `range` (not the Python-2-only `xrange`) works on both Python 2 and 3.
    for ii in range(level):
        disjoint_set.union(tree[ii, 0], tree[ii, 1])

    return np.array([disjoint_set.find_representative(ii)
                     for ii in range(nr_nodes)], dtype=np.int32)
def get_images(labels, label_to_repr, to_save=False):
    """Colour each label frame by cluster, then display or save it.

    Parameters
    ----------
    labels : iterable of integer arrays
        One frame of node ids per image. Every id must be a valid index
        into ``label_to_repr``.
    label_to_repr : 1-D integer array
        Maps a node id to its cluster representative (see ``aggregate_tree``).
    to_save : bool
        If true, write frame ``i`` to ``'%06d.jpg' % i`` in the working
        directory via ``imsave``. Otherwise show it with matplotlib;
        ``plt.show()`` is called once per frame.

    Each representative gets a random RGB colour from ``np.random.rand``, so
    colours differ between runs unless the NumPy RNG is seeded.
    """
    # Size the palette by the number of node ids, not by the largest label.
    # A representative can be any node id, including ids that never occur
    # in `labels`. This sizing also avoids np.max() failing on empty input.
    rgb = np.random.rand(len(label_to_repr), 3)
    for ii, frame in enumerate(labels):
        # Integer-array indexing keeps the frame's shape and appends an
        # RGB axis: (H, W) -> (H, W, 3).
        rgb_frame = rgb[label_to_repr[frame]]
        if to_save:
            imsave('%06d.jpg' % ii, rgb_frame)
        else:
            plt.imshow(rgb_frame)
            plt.show()
def main():
    """Command-line entry point: ``python aggregate_tree.py LEVEL [NPZ_PATH]``.

    Loads the ``tree`` and ``labels`` arrays from the .npz archive at
    NPZ_PATH (default ``001.npz``), applies LEVEL merges, and displays the
    resulting segmentation frame by frame.
    """
    if len(sys.argv) < 2:
        sys.exit('Usage: python aggregate_tree.py LEVEL [HIERARCHY.npz]')
    level = int(sys.argv[1])
    # Default data: wget http://pascal.inrialpes.fr/data2/oneata/data/msr_ii/hierarchy/001.npz
    path = sys.argv[2] if len(sys.argv) > 2 else '001.npz'
    # Close the archive as soon as both arrays are read. Each key lookup
    # reads from the file, so `labels` is read once and reused.
    with np.load(path) as hierarchy:
        tree = hierarchy['tree']
        labels = hierarchy['labels']
    label_to_repr = aggregate_tree(tree, labels, level)
    get_images(labels, label_to_repr)
if __name__ == "__main__":
    # Only run the command-line entry point when executed as a script.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment