Skip to content

Instantly share code, notes, and snippets.

@johnPertoft
Last active November 4, 2021 14:39
Show Gist options
  • Save johnPertoft/4b909fd099b60df01a041cd98f17a1dc to your computer and use it in GitHub Desktop.
Save johnPertoft/4b909fd099b60df01a041cd98f17a1dc to your computer and use it in GitHub Desktop.
(Slow) TensorFlow implementation of Locality-Aware NMS as described in the EAST paper (https://arxiv.org/abs/1704.03155)
import tensorflow as tf
def locality_aware_nms(vertices, probs, iou_threshold, name=None):
    """
    Implements the Locality-Aware NMS algorithm as described in EAST (https://arxiv.org/abs/1704.03155).

    Boxes are sorted by their smallest y coordinate and then folded over in that order. While the
    current box overlaps the running box with an IoU of at least `iou_threshold`, the two are merged
    (score-weighted average of the vertices, scores summed). Otherwise the running box is emitted
    and the current box becomes the new running box. The emitted boxes are finally filtered with
    standard NMS.

    Args:
        vertices: 3-D tensor of shape (?, 4, 2) representing the unmerged boxes. Cast to float32
            internally, so integer pixel coordinates are accepted.
        probs: Tensor of shape (?, 1) or (?,) representing the unmerged boxes' probabilities.
            Cast to float32 and reshaped to (?, 1) internally.
        iou_threshold: Intersection over union threshold for when to merge boxes.
        name: Optional name of this operation.
    Returns:
        vertices: 3-D tensor of shape (?, 4, 2) representing the merged boxes.
        scores: 1-D tensor of shape (?,) representing the merged boxes' scores. A merged box
            carries the sum of the scores of the boxes that were merged into it.
    """
    with tf.name_scope(name, 'locality_aware_nms'):
        assert vertices.shape[1:] == (4, 2), 'Assumes bounding boxes given as (4, 2).'

        # The fold accumulators below are float32 (tf.zeros default), so the rows fed into the
        # fold must be float32 as well. Casting is a no-op for inputs that already are.
        vertices = tf.cast(vertices, tf.float32)
        # Accept both (?, 1) and (?,) probabilities; the concat below needs a column.
        probs = tf.reshape(tf.cast(probs, tf.float32), (-1, 1))

        # Sort bounding boxes row wise by sorting based on top coordinate of each bounding box.
        vertices_sort_key = tf.reduce_min(vertices[..., 1], axis=1)
        sorted_indices = tf.argsort(vertices_sort_key)
        vertices = tf.gather(vertices, sorted_indices)
        probs = tf.gather(probs, sorted_indices)

        # Each row: 8 flattened vertex coordinates followed by the score.
        vertices_with_probs = tf.concat((tf.reshape(vertices, (-1, 8)), probs), axis=1)

        def should_merge(p, g):
            """Compute whether two rectangles (given as (1, 9) rows) should merge."""
            p = p[:, :8]
            p = tf.reshape(p, (4, 2))
            g = g[:, :8]
            g = tf.reshape(g, (4, 2))
            return intersection_over_union(p, g) >= iou_threshold

        def weighted_merge(p, g):
            """Merge two rectangles (given as (1, 9) rows) into a single (1, 9) row."""
            p_score = p[:, 8]
            p_vertices = p[:, :8]
            g_score = g[:, 8]
            g_vertices = g[:, :8]
            new_score = p_score + g_score
            new_score = new_score[:, tf.newaxis]
            # Score-weighted average of the vertex coordinates.
            new_vertices = (p_score * p_vertices + g_score * g_vertices) / (p_score + g_score)
            return tf.concat((new_vertices, new_score), axis=1)

        def locality_aware_nms_step(S, p, g):
            """One fold step: S holds emitted boxes, p the running box (0 or 1 rows), g the current row."""
            g = g[tf.newaxis, ...]
            # An empty p plays the role of "p is None" in the paper's pseudo code.
            p_is_none = tf.equal(tf.shape(p)[0], 0)

            def p_is_none_branch():
                # Set p to g.
                return S, g

            def p_is_not_none_branch():
                def should_merge_branch():
                    # Set p to weighted merge of p and g.
                    return S, weighted_merge(p, g)

                def should_not_merge_branch():
                    # Add p to S.
                    # Set p to g.
                    return tf.concat((S, p), axis=0), g

                return tf.cond(should_merge(p, g), true_fn=should_merge_branch, false_fn=should_not_merge_branch)

            return tf.cond(p_is_none, true_fn=p_is_none_branch, false_fn=p_is_not_none_branch)

        # Both accumulators change their number of rows during the fold, so their static first
        # dimension has to be unknown.
        S = forget_first_dimension(tf.zeros((0, 9)))
        p = forget_first_dimension(tf.zeros((0, 9)))

        # Fold over all rectangles.
        S, p = tf.foldl(
            lambda acc, curr: locality_aware_nms_step(*acc, curr),
            vertices_with_probs,
            initializer=(S, p),
            parallel_iterations=1)

        # Emit the last running box. No effect if p is empty.
        S = tf.concat((S, p), axis=0)

        S = _standard_nms(S, iou_threshold)

        vertices = S[..., :8]
        vertices = tf.reshape(vertices, (-1, 4, 2))
        scores = S[..., 8]
        return vertices, scores
def _standard_nms(S, iou_threshold):
    """
    Implements the standard NMS algorithm.

    Rows are visited in descending score order. Each visited row is kept and every remaining
    candidate whose IoU with it exceeds `iou_threshold` is discarded.

    Args:
        S: 2-D float32 tensor of shape (?, 9); each row holds 8 flattened vertex coordinates
            followed by the score. The number of rows may be statically known or unknown.
        iou_threshold: Candidates with an IoU above this value against a kept box are removed.
    Returns:
        keep: 2-D tensor of shape (?, 9) with the surviving rows in descending score order.
    """
    def S_not_empty_branch():
        def nms_while_condition(_, candidates):
            return tf.shape(candidates)[0] > 0

        def nms_step(keep, candidates):
            # The best remaining candidate is always kept.
            g = candidates[0]
            g = g[tf.newaxis, :]
            keep = tf.concat((keep, g), axis=0)
            candidates = candidates[1:]

            g = tf.reshape(g[:, :8], (4, 2))

            def g_iou(x):
                x = x[:8]
                x = tf.reshape(x, (4, 2))
                return intersection_over_union(g, x)

            # Drop every candidate that overlaps the kept box too much.
            ious = tf.map_fn(g_iou, candidates)
            candidates = tf.gather(candidates, tf.where(ious <= iou_threshold)[:, 0])
            return keep, candidates

        # Sort S based on scores.
        sorted_indices = tf.argsort(S[:, 8], direction='DESCENDING')
        candidates = tf.gather(S, sorted_indices)
        keep = tf.zeros((0, 9))

        # Both loop variables change their number of rows between iterations. Declaring that
        # explicitly makes the loop valid even when S has a static first dimension.
        rows_may_vary = tf.TensorShape([None, 9])
        keep, _ = tf.while_loop(
            cond=nms_while_condition,
            body=nms_step,
            loop_vars=(keep, candidates),
            shape_invariants=(rows_may_vary, rows_may_vary),
            parallel_iterations=1)
        return keep

    S_empty = tf.equal(tf.shape(S)[0], 0)
    # Nothing to suppress for an empty input; return it untouched.
    return tf.cond(S_empty, true_fn=lambda: S, false_fn=S_not_empty_branch)
def polygon_area(vertices, n_corners, name=None):
    """
    Compute the area of a polygon from its ordered vertices using the formula described at
    http://mathworld.wolfram.com/PolygonArea.html

    Args:
        vertices: 2-D tensor of shape (?, 2) where the last dimension is the xy coordinate.
        n_corners: Number of leading vertices to take into account.
        name: Optional name of this operation.
    Returns:
        area: The (non-negative) polygon area.
    """
    with tf.name_scope(name, 'polygon_area'):
        corners = vertices[:n_corners]
        # Pair every corner with its successor; the last corner wraps around to the first.
        next_corners = tf.roll(corners, shift=-1, axis=0)

        xs, ys = corners[:, 0], corners[:, 1]
        next_xs, next_ys = next_corners[:, 0], next_corners[:, 1]

        # The summed cross products give twice the signed area; the sign depends on the
        # vertex orientation, so it is removed at the end.
        twice_signed_area = tf.reduce_sum(xs * next_ys - next_xs * ys)
        return tf.abs(twice_signed_area / 2)
def rectangles_intersection_area(subject_rectangle, clip_rectangle, name=None):
    """
    Compute the intersection area between two rectangles. Implements the Sutherland-Hodgman algorithm
    described at https://en.wikipedia.org/wiki/Sutherland%E2%80%93Hodgman_algorithm

    The subject rectangle is clipped against each of the four edges of the clip rectangle in turn;
    the area of the polygon that remains is the intersection area.

    Args:
        subject_rectangle: 2-D tensor of shape (4, 2) representing the four vertices of the subject rectangle
            defined in clockwise order starting at the top left vertex.
        clip_rectangle: 2-D tensor of shape (4, 2) representing the four vertices of the clip rectangle
            defined in clockwise order starting at the top left vertex.
        name: Optional name of this operation.
    Returns:
        area: The computed area of the intersection polygon.
    """
    # TODO: See if this can be made faster by considering that we only have rectangles.
    # Sutherland-Hodgman is a general algorithm for convex polygons.
    with tf.name_scope(name, 'rectangles_intersection_area'):

        def inside(p, edge):
            """
            Compute whether p is inside (right of) edge.

            p is a (1, 2) point and edge a (2, 2) tensor holding the edge's start and end point.
            The comparison is strict, so a point lying exactly on the edge's line counts as outside.
            """
            p_x = p[0, 0]
            p_y = p[0, 1]
            e1_x = edge[0, 0]
            e1_y = edge[0, 1]
            e2_x = edge[1, 0]
            e2_y = edge[1, 1]
            # Sign test of the cross product (e2 - e1) x (p - e1).
            return (e2_x - e1_x) * (p_y - e1_y) > (e2_y - e1_y) * (p_x - e1_x)

        def compute_intersection(p1, p2, edge):
            """
            Compute the intersection point of a line segment and an infinite edge.

            p1 and p2 are (1, 2) points, edge is a (2, 2) tensor. Returns the intersection as a
            (1, 2) tensor. The denominator is zero for parallel lines; the callers below only
            invoke this when exactly one of the two points is inside the edge.
            """
            p1_x = p1[0, 0]
            p1_y = p1[0, 1]
            p2_x = p2[0, 0]
            p2_y = p2[0, 1]
            e1_x = edge[0, 0]
            e1_y = edge[0, 1]
            e2_x = edge[1, 0]
            e2_y = edge[1, 1]
            # Direction vectors of the clip edge and of the segment.
            dc = [e1_x - e2_x, e1_y - e2_y]
            dp = [p2_x - p1_x, p2_y - p1_y]
            n1 = e1_x * e2_y - e1_y * e2_x
            n2 = p2_x * p1_y - p2_y * p1_x
            # Reciprocal of the determinant of the two direction vectors.
            n3 = 1.0 / (dc[0] * dp[1] - dc[1] * dp[0])
            p = [(n1 * dp[0] - n2 * dc[0]) * n3, (n1 * dp[1] - n2 * dc[1]) * n3]
            p = tf.stack(p, axis=0)
            p = p[tf.newaxis, :]
            return p

        def clip_edge_loop_step(vertices, clip_edge):
            """
            Applied for each clip edge.

            Clips the polygon `vertices` (shape (?, 2)) against `clip_edge` (shape (2, 2)) and
            returns the vertices of the clipped polygon.
            """

            def point_loop_step(output_vertices, s, e):
                """
                Applied at each point of the current vertex list.

                s is the previously visited vertex as a (1, 2) tensor, e the current vertex.
                Returns the extended output list and e (as a (1, 2) tensor), which becomes s of the next step.
                """
                e = e[tf.newaxis, :]

                def e_inside_clip_edge_branch():
                    output = output_vertices

                    def s_not_inside_clip_edge_branch():
                        # The segment s -> e enters the clip region: add the entry point.
                        return tf.concat((output, compute_intersection(s, e, clip_edge)), axis=0)

                    # Both branch functions are invoked while tf.cond builds the graph, i.e.
                    # before `output` is rebound on the following lines.
                    output = tf.cond(
                        tf.math.logical_not(inside(s, clip_edge)),
                        true_fn=s_not_inside_clip_edge_branch,
                        false_fn=lambda: output)
                    # e itself is inside, so it is always part of the output.
                    output = tf.concat((output, e), axis=0)
                    return output

                def e_not_inside_clip_edge_branch():
                    output = output_vertices

                    def s_inside_clip_edge_branch():
                        # The segment s -> e leaves the clip region: add the exit point.
                        return tf.concat((output, compute_intersection(s, e, clip_edge)), axis=0)

                    output = tf.cond(
                        inside(s, clip_edge),
                        true_fn=s_inside_clip_edge_branch,
                        false_fn=lambda: output)
                    return output

                output_vertices = tf.cond(
                    inside(e, clip_edge),
                    true_fn=e_inside_clip_edge_branch,
                    false_fn=e_not_inside_clip_edge_branch)
                return output_vertices, e

            input_vertices = vertices
            # An earlier clip edge may already have removed every vertex.
            input_vertices_empty = tf.equal(tf.shape(input_vertices)[0], 0)

            def input_vertices_empty_branch():
                return input_vertices

            def input_vertices_not_empty_branch():
                # Start from the last vertex so that the closing segment (last -> first) is handled too.
                s = input_vertices[-1]
                s = s[tf.newaxis, :]
                # Start with empty list of output vertices.
                # Its row count changes while folding, hence the unknown first dimension.
                output_vertices = forget_first_dimension(tf.zeros((0, 2)))
                output_vertices, _ = tf.foldl(
                    lambda acc, curr: point_loop_step(*acc, curr),
                    input_vertices,
                    initializer=(output_vertices, s),
                    parallel_iterations=1)
                return output_vertices

            output_vertices = tf.cond(
                input_vertices_empty,
                true_fn=input_vertices_empty_branch,
                false_fn=input_vertices_not_empty_branch)
            return output_vertices

        # Shape (4, 2, 2): entry i holds the clip edge from vertex i to vertex i + 1 (wrapping around).
        clip_edges = tf.stack((clip_rectangle, tf.roll(clip_rectangle, shift=-1, axis=0)), axis=1)
        # The number of polygon vertices changes from one clip edge to the next.
        initial_vertices = forget_first_dimension(subject_rectangle)
        vertices = tf.foldl(
            clip_edge_loop_step,
            clip_edges,
            initializer=initial_vertices,
            parallel_iterations=1)
        n_corners = tf.shape(vertices)[0]
        area = polygon_area(vertices, n_corners)
        return area
def intersection_over_union(a, b):
    """
    Compute the intersection over union between two rectangles.

    Args:
        a: 2-D tensor of shape (4, 2) representing the four vertices of the first rectangle
            defined in clockwise order starting at the top left vertex.
        b: 2-D tensor of shape (4, 2) representing the four vertices of the second rectangle
            defined in clockwise order starting at the top left vertex.
    Returns:
        iou: The ratio of intersection area over union area as a float32 scalar. Zero when the
            union area is zero (both rectangles degenerate).
    """
    a = tf.cast(a, tf.float32)
    b = tf.cast(b, tf.float32)

    # TODO: Add fast check if they intersect before computing areas.

    def rectangle_area(r):
        return polygon_area(r, 4)

    intersection = rectangles_intersection_area(a, b)
    union = rectangle_area(a) + rectangle_area(b) - intersection
    # A plain division yields NaN for a zero union, and every comparison against NaN is False,
    # which would make callers silently drop or never merge such boxes.
    return tf.math.divide_no_nan(intersection, union)
def forget_first_dimension(x):
    """
    Return `x` with its static first dimension replaced by an unknown size.

    The tensor is routed through a `tf.placeholder_with_default` whose shape has `None` as the
    first dimension; the placeholder is then marked as non-feedable on the tensor's graph.

    Args:
        x: N-D tensor.
    Returns:
        x: A tensor with the same values whose static shape is (?, ...) instead of (n, ...).
    """
    trailing_dims = x.shape.as_list()[1:]
    relaxed = tf.placeholder_with_default(x, [None] + trailing_dims)
    # The placeholder exists only to loosen the static shape; it must never be fed.
    x.graph.prevent_feeding(relaxed)
    return relaxed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment