Last active
November 4, 2021 14:39
-
-
Save johnPertoft/4b909fd099b60df01a041cd98f17a1dc to your computer and use it in GitHub Desktop.
(Slow) TensorFlow implementation of Locality-Aware NMS as described in EAST (https://arxiv.org/abs/1704.03155).
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
def locality_aware_nms(vertices, probs, iou_threshold, name=None):
    """
    Implements the Locality-Aware NMS algorithm as described in EAST (https://arxiv.org/abs/1704.03155).

    Boxes are sorted by their smallest y coordinate and then folded over in that order. Each box is
    either merged into the running box (score-weighted average of the vertices, sum of the scores)
    when their IoU reaches `iou_threshold`, or the running box is emitted and replaced by the new
    box. The emitted boxes are finally filtered with standard NMS.

    Args:
        vertices: 3-D tensor of shape (?, 4, 2) representing the unmerged boxes. Cast to float32.
        probs: 2-D tensor of shape (?, 1) representing the unmerged boxes' probabilities. Cast to float32.
        iou_threshold: Intersection over union threshold for when to merge boxes.
        name: Optional name of this operation.
    Returns:
        vertices: 3-D tensor of shape (?, 4, 2) representing the merged boxes.
        scores: 1-D tensor of shape (?,) representing the merged boxes' scores. A merged box's
            score is the sum of the probabilities of the boxes merged into it.
    """
    with tf.name_scope(name, 'locality_aware_nms'):
        assert vertices.shape[1:] == (4, 2), 'Assumes bounding boxes given as (4, 2).'

        # The fold accumulators below are created with tf.zeros, i.e. float32. tf.concat, tf.cond
        # and the loop behind tf.foldl all require matching dtypes, so normalise the inputs once
        # up front. float32 inputs are left unchanged by the cast.
        vertices = tf.cast(vertices, tf.float32)
        probs = tf.cast(probs, tf.float32)

        # Sort bounding boxes row wise by sorting based on top coordinate of each bounding box.
        vertices_sort_key = tf.reduce_min(vertices[..., 1], axis=1)
        sorted_indices = tf.argsort(vertices_sort_key)
        vertices = tf.gather(vertices, sorted_indices)
        probs = tf.gather(probs, sorted_indices)

        # One row per box: 8 flattened vertex coordinates followed by the score.
        vertices_with_probs = tf.concat((tf.reshape(vertices, (-1, 8)), probs), axis=1)

        def should_merge(p, g):
            """Compute whether two rectangles (rows of shape (1, 9)) should merge."""
            p = p[:, :8]
            p = tf.reshape(p, (4, 2))
            g = g[:, :8]
            g = tf.reshape(g, (4, 2))
            return intersection_over_union(p, g) >= iou_threshold

        def weighted_merge(p, g):
            """Merge two rectangles (rows of shape (1, 9)) into one row of shape (1, 9)."""
            p_score = p[:, 8]
            p_vertices = p[:, :8]
            g_score = g[:, 8]
            g_vertices = g[:, :8]
            score_sum = p_score + g_score
            # Vertices are averaged with the scores as weights; the scores themselves add up.
            new_vertices = (p_score * p_vertices + g_score * g_vertices) / score_sum
            new_score = score_sum[:, tf.newaxis]
            return tf.concat((new_vertices, new_score), axis=1)

        def locality_aware_nms_step(S, p, g):
            """One fold step: S holds emitted rows, p the running row (empty = none), g the new row."""
            g = g[tf.newaxis, ...]
            p_is_none = tf.equal(tf.shape(p)[0], 0)

            def p_is_none_branch():
                # Set p to g.
                return S, g

            def p_is_not_none_branch():
                def should_merge_branch():
                    # Set p to weighted merge of p and g.
                    return S, weighted_merge(p, g)

                def should_not_merge_branch():
                    # Add p to S.
                    # Set p to g.
                    return tf.concat((S, p), axis=0), g

                return tf.cond(should_merge(p, g), true_fn=should_merge_branch, false_fn=should_not_merge_branch)

            return tf.cond(p_is_none, true_fn=p_is_none_branch, false_fn=p_is_not_none_branch)

        # Both accumulators change their number of rows between iterations, so their static
        # first dimension must be unknown for the loop to accept them.
        S = forget_first_dimension(tf.zeros((0, 9)))
        p = forget_first_dimension(tf.zeros((0, 9)))

        # Fold over all rectangles.
        S, p = tf.foldl(
            lambda acc, curr: locality_aware_nms_step(*acc, curr),
            vertices_with_probs,
            initializer=(S, p),
            parallel_iterations=1)

        # Flush the last running rectangle. No effect if p is empty.
        S = tf.concat((S, p), axis=0)

        S = _standard_nms(S, iou_threshold)

        vertices = S[..., :8]
        vertices = tf.reshape(vertices, (-1, 4, 2))
        scores = S[..., 8]

        return vertices, scores
def _standard_nms(S, iou_threshold):
    """
    Greedy non-maximum suppression over the rows of `S`.

    Rows are visited in descending order of column 8 (the score). The best remaining row is kept
    and every other remaining row whose IoU with it exceeds `iou_threshold` is discarded.

    Args:
        S: 2-D tensor whose rows hold 8 flattened vertex coordinates followed by a score.
        iou_threshold: Rows overlapping a kept row by more than this IoU are suppressed.
    Returns:
        The kept rows in descending score order, or `S` itself when it has no rows.
    """
    def suppress():
        def has_candidates(_, remaining):
            return tf.shape(remaining)[0] > 0

        def keep_best(kept, remaining):
            # `remaining` is sorted by score, so its first row is the best box left.
            best = tf.expand_dims(remaining[0], axis=0)
            kept = tf.concat((kept, best), axis=0)
            rest = remaining[1:]
            best_box = tf.reshape(best[:, :8], (4, 2))

            def iou_with_best(row):
                return intersection_over_union(best_box, tf.reshape(row[:8], (4, 2)))

            overlaps = tf.map_fn(iou_with_best, rest)
            # Only rows that do not overlap the kept box too much stay in the running.
            survivors = tf.where(overlaps <= iou_threshold)[:, 0]
            return kept, tf.gather(rest, survivors)

        # Highest score first.
        by_score = tf.argsort(S[:, 8], direction='DESCENDING')
        ranked = tf.gather(S, by_score)

        # The kept set grows every iteration, so its static first dimension must be unknown.
        nothing_kept = forget_first_dimension(tf.zeros((0, 9)))

        kept, _ = tf.while_loop(
            cond=has_candidates,
            body=keep_best,
            loop_vars=(nothing_kept, ranked),
            parallel_iterations=1)
        return kept

    is_empty = tf.equal(tf.shape(S)[0], 0)
    return tf.cond(is_empty, true_fn=lambda: S, false_fn=suppress)
def polygon_area(vertices, n_corners, name=None):
    """
    Compute the area of a polygon from its corners using the formula described at
    http://mathworld.wolfram.com/PolygonArea.html

    Args:
        vertices: 2-D tensor of shape (?, 2) where the last dimension is the xy coordinate.
        n_corners: Number of leading rows of `vertices` to use as polygon corners.
        name: Optional name of this operation.
    Returns:
        area: The computed polygon area (non-negative scalar).
    """
    with tf.name_scope(name, 'polygon_area'):
        corners = vertices[:n_corners]
        # Pair every corner with its successor; the last corner wraps around to the first.
        successors = tf.roll(corners, shift=-1, axis=0)
        cross_terms = corners[:, 0] * successors[:, 1] - successors[:, 0] * corners[:, 1]
        # The signed sum depends on the winding direction; abs() removes that dependence.
        return tf.abs(tf.reduce_sum(cross_terms) / 2)
def rectangles_intersection_area(subject_rectangle, clip_rectangle, name=None):
    """
    Compute the intersection area between two rectangles by clipping the subject rectangle
    against every edge of the clip rectangle (Sutherland-Hodgman algorithm, see
    https://en.wikipedia.org/wiki/Sutherland%E2%80%93Hodgman_algorithm) and measuring the
    polygon that is left.

    Args:
        subject_rectangle: 2-D tensor of shape (4, 2) representing the four vertices of the subject
            rectangle defined in clockwise order starting at the top left vertex.
        clip_rectangle: 2-D tensor of shape (4, 2) representing the four vertices of the clip
            rectangle defined in clockwise order starting at the top left vertex.
        name: Optional name of this operation.
    Returns:
        area: The computed area of the intersection polygon.
    """
    # TODO: See if this can be made faster by considering that we only have rectangles.
    # Sutherland-Hodgman is a general algorithm for convex polygons.
    with tf.name_scope(name, 'rectangles_intersection_area'):

        def is_inside(point, edge):
            """Whether `point` (shape (1, 2)) lies strictly on the inner side of `edge` (shape (2, 2))."""
            edge_dx = edge[1, 0] - edge[0, 0]
            edge_dy = edge[1, 1] - edge[0, 1]
            rel_x = point[0, 0] - edge[0, 0]
            rel_y = point[0, 1] - edge[0, 1]
            return edge_dx * rel_y > edge_dy * rel_x

        def intersect(seg_start, seg_end, edge):
            """Intersection, shape (1, 2), of the segment seg_start-seg_end with the line through `edge`."""
            ax, ay = seg_start[0, 0], seg_start[0, 1]
            bx, by = seg_end[0, 0], seg_end[0, 1]
            cx, cy = edge[0, 0], edge[0, 1]
            dx, dy = edge[1, 0], edge[1, 1]

            edge_delta_x = cx - dx
            edge_delta_y = cy - dy
            seg_delta_x = bx - ax
            seg_delta_y = by - ay

            edge_cross = cx * dy - cy * dx
            seg_cross = bx * ay - by * ax
            inv_det = 1.0 / (edge_delta_x * seg_delta_y - edge_delta_y * seg_delta_x)

            hit_x = (edge_cross * seg_delta_x - seg_cross * edge_delta_x) * inv_det
            hit_y = (edge_cross * seg_delta_y - seg_cross * edge_delta_y) * inv_det
            return tf.stack([hit_x, hit_y], axis=0)[tf.newaxis, :]

        def clip_against(polygon, edge):
            """Clip `polygon` (shape (?, 2)) against a single clip edge."""

            def visit_point(clipped, prev, curr):
                """Process the polygon side prev -> curr and extend the clipped vertex list."""
                curr = curr[tf.newaxis, :]
                curr_in = is_inside(curr, edge)
                prev_in = is_inside(prev, edge)

                # The side crosses the clip edge exactly when one endpoint is inside and the
                # other is not; the crossing point is then emitted before anything else.
                clipped = tf.cond(
                    tf.math.logical_xor(curr_in, prev_in),
                    true_fn=lambda: tf.concat((clipped, intersect(prev, curr, edge)), axis=0),
                    false_fn=lambda: clipped)

                # The end point itself survives only when it is inside.
                clipped = tf.cond(
                    curr_in,
                    true_fn=lambda: tf.concat((clipped, curr), axis=0),
                    false_fn=lambda: clipped)

                return clipped, curr

            def clip_non_empty():
                # The first side processed runs from the last vertex to the first one.
                last_vertex = polygon[-1][tf.newaxis, :]
                # The output list grows while folding, so its first dimension must be unknown.
                no_vertices = forget_first_dimension(tf.zeros((0, 2)))
                clipped, _ = tf.foldl(
                    lambda state, point: visit_point(state[0], state[1], point),
                    polygon,
                    initializer=(no_vertices, last_vertex),
                    parallel_iterations=1)
                return clipped

            # A polygon that was already clipped away completely stays empty.
            polygon_is_empty = tf.equal(tf.shape(polygon)[0], 0)
            return tf.cond(
                polygon_is_empty,
                true_fn=lambda: polygon,
                false_fn=clip_non_empty)

        # Edge i runs from clip vertex i to clip vertex i + 1 (wrapping around): shape (4, 2, 2).
        next_corners = tf.roll(clip_rectangle, shift=-1, axis=0)
        clip_edges = tf.stack((clip_rectangle, next_corners), axis=1)

        remaining = tf.foldl(
            clip_against,
            clip_edges,
            initializer=forget_first_dimension(subject_rectangle),
            parallel_iterations=1)

        return polygon_area(remaining, tf.shape(remaining)[0])
def intersection_over_union(a, b, name=None):
    """
    Compute the intersection over union between two rectangles.

    Args:
        a: 2-D tensor of shape (4, 2) representing the four vertices of the first rectangle
            defined in clockwise order starting at the top left vertex.
        b: 2-D tensor of shape (4, 2) representing the four vertices of the second rectangle
            defined in clockwise order starting at the top left vertex.
        name: Optional name of this operation.
    Returns:
        iou: The ratio of intersection area over union area as a float32 scalar, or 0 when the
            union area is 0 (both rectangles degenerate).
    """
    with tf.name_scope(name, 'intersection_over_union'):
        a = tf.cast(a, tf.float32)
        b = tf.cast(b, tf.float32)

        # TODO: Add fast check if they intersect before computing areas.

        def rectangle_area(r):
            return polygon_area(r, 4)

        intersection = rectangles_intersection_area(a, b)
        union = rectangle_area(a) + rectangle_area(b) - intersection

        # A plain division yields NaN for two zero-area rectangles, and NaN fails every
        # threshold comparison downstream. Report "no overlap" for that case instead.
        return tf.math.divide_no_nan(intersection, union)
def forget_first_dimension(x):
    """
    Return `x` with the static size of its first dimension erased, so that it can be used where
    the number of rows changes at run time (e.g. as a loop accumulator).

    Args:
        x: N-D tensor.
    Returns:
        x: A tensor with the same values whose static shape is `[None] + x.shape[1:]`.
    """
    trailing_dims = x.shape.as_list()[1:]
    owner_graph = x.graph
    relaxed = tf.placeholder_with_default(x, [None] + trailing_dims)
    # The placeholder is used only for its relaxed shape, so feeding it is disabled.
    owner_graph.prevent_feeding(relaxed)
    return relaxed
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment