Skip to content

Instantly share code, notes, and snippets.

@petered
Created August 10, 2022 22:43
Show Gist options
  • Save petered/5bb97f751bb59475fdfaa60d771e4663 to your computer and use it in GitHub Desktop.
Save petered/5bb97f751bb59475fdfaa60d771e4663 to your computer and use it in GitHub Desktop.
An algorithm for computing a series of bounding boxes from a boolean mask
import tensorflow as tf
def tf_mask_to_boxes(mask, insert_fake_first_box: bool = False):
    """
    Convert a boolean mask to a series of bounding boxes around each segment.

    Warning: VERY SLOW (slower than pure python version of same algorithm)

    The mask is swept twice with nested ``tf.scan`` loops (once as given, once
    transposed).  Each sweep records, for every active cell, the index at which
    its segment starts along that axis.  A cell that is the last active cell
    both along its row and along its column is treated as the bottom-right
    corner of a box, and the recorded start indices give the top-left corner.

    Only tensor ops (no Python ``and``/``or``/``if`` on tensors, no static
    ``len``/``.shape``) are used, so the function does not depend on eager
    execution or AutoGraph and also works when the mask dimensions are only
    known at run time.

    :param mask: A (HxW) boolean mask
    :param insert_fake_first_box: TFLite seems to have a bug
        (see issue) https://github.com/tensorflow/tensorflow/issues/57084 where
        it cannot return zero-boxes.  So this flag inserts a fake box of all
        zeros first.
    :returns: A (Nx4) int32 array of (Left, Top, Right, Bottom) box bounds.
        Right and Bottom are one past the last cell of the box.
    """
    # No-op for tensors that are already boolean; lets array-likes through too.
    mask = tf.cast(mask, tf.bool)

    # Sentinel meaning "this cell is not part of any segment".  Valid indices
    # are at most max(H, W) - 1, so the sentinel never collides with one.
    max_ix = tf.reduce_max(tf.shape(mask))

    def drag_index_right(last_ix, this_col_inputs):
        """Scan step along a row: propagate the segment start index rightwards."""
        this_mask_cell, above_cell_index, this_col_ix = this_col_inputs
        # A cell that is off in the mask is still filled in when both its left
        # and its upper neighbours are active.
        left_and_above_active = tf.logical_and(
            tf.not_equal(last_ix, max_ix),
            tf.not_equal(above_cell_index, max_ix),
        )
        still_active = tf.logical_or(this_mask_cell, left_and_above_active)
        # Inactive neighbours hold the sentinel (the largest value), so the
        # minimum simply ignores them.
        start_ix = tf.minimum(tf.minimum(last_ix, this_col_ix), above_cell_index)
        return tf.where(still_active, start_ix, max_ix)

    def drag_row_down(above_row, this_mask_row):
        """Scan step down the rows: compute one row of start indices."""
        column_ixs = tf.range(tf.shape(this_mask_row)[0])
        return tf.scan(
            drag_index_right,
            elems=(this_mask_row, above_row, column_ixs),
            initializer=max_ix,
        )

    def compute_indices_down(mask_):
        """Return (start-index grid, mask of cells where a segment ends in its row)."""
        horizontal_index_grid = tf.scan(
            drag_row_down,
            elems=mask_,
            initializer=tf.fill((tf.shape(mask_)[1],), value=max_ix),
        )
        active_mask = tf.not_equal(horizontal_index_grid, max_ix)
        # A segment stops where an active cell is followed by an inactive one,
        # or where an active cell sits in the last column.
        stops_before_last_col = tf.logical_and(
            active_mask[:, :-1], tf.logical_not(active_mask[:, 1:])
        )
        stop_mask = tf.concat([stops_before_last_col, active_mask[:, -1:]], axis=1)
        return horizontal_index_grid, stop_mask

    x_ixs, x_stop_mask = compute_indices_down(mask)
    # Running the same sweep on the transposed mask yields the vertical info.
    y_ixs_t, y_stop_mask_t = compute_indices_down(tf.transpose(mask))
    y_ixs = tf.transpose(y_ixs_t)
    y_stop_mask = tf.transpose(y_stop_mask_t)

    # Bottom-right corners: cells that end a segment in both directions.
    corner_mask = tf.logical_and(x_stop_mask, y_stop_mask)
    yx_stops_ixs = tf.cast(tf.where(corner_mask), tf.int32)  # (N, 2) as (row, col)
    x_starts = tf.gather_nd(x_ixs, yx_stops_ixs)
    y_starts = tf.gather_nd(y_ixs, yx_stops_ixs)
    ltrb_boxes = tf.concat(
        [
            x_starts[:, None],
            y_starts[:, None],
            yx_stops_ixs[:, 1][:, None] + 1,  # right = stop column + 1
            yx_stops_ixs[:, 0][:, None] + 1,  # bottom = stop row + 1
        ],
        axis=1,
    )
    if insert_fake_first_box:
        ltrb_boxes = tf.concat([tf.zeros((1, 4), dtype=tf.int32), ltrb_boxes], axis=0)
    return ltrb_boxes
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment