Skip to content

Instantly share code, notes, and snippets.

@Ending2015a
Last active January 21, 2022 09:50
Show Gist options
  • Save Ending2015a/215375b470dcdd50de3c9b2252337888 to your computer and use it in GitHub Desktop.
Save Ending2015a/215375b470dcdd50de3c9b2252337888 to your computer and use it in GitHub Desktop.
Scatter N-D (masked scatter-max), TensorFlow 2.0 implementation. This function can be used in Active Neural SLAM to project depth maps onto a top-down height map.
import tensorflow as tf
import numpy as np
def masked_scatter_nd_max(canvas, indices, values, mask=None):
    '''Scatter-max vector values into an n-D canvas, skipping invalid points.

    For every point ``j``, the vector ``values[..., j, :]`` is merged into the
    canvas cell ``canvas[..., indices[..., j, 0], ..., indices[..., j, n-1], :]``
    with an element-wise maximum (``tf.tensor_scatter_nd_max``). Points whose
    mask is False, or whose coordinates fall outside the canvas, are dropped.

    For projecting onto an image-type canvas, n = 2: the canvas has shape
    (..., h, w, v) and ``indices`` holds (h, w) coordinates. When projecting a
    point cloud to a height map, ``values`` is the flattened batch of point
    clouds with ``N`` points and ``v`` = 3 (xyz coords).

    The maximum is taken per vector component. The vector stored in a cell may
    therefore combine components coming from different points.

    Args:
        canvas (tf.Tensor): Canvas with shape (..., d1, ..., dn, v). Its
            leading "..." batch dimensions must match those of ``indices``, and
            there must be at least one of them.
        indices (tf.Tensor): (d1, ..., dn) integer coordinates where each value
            is scattered, with shape (..., N, n). Cast to int32.
        values (tf.Tensor): Vector values with shape (..., N, v). Must share
            the canvas dtype (``tf.tensor_scatter_nd_max`` requires it).
        mask (tf.Tensor, optional): Mask with shape (..., N), where valid
            points are True. Defaults to all True.

    Returns:
        tf.Tensor: Updated canvas, shape (..., d1, ..., dn, v).
        tf.Tensor: Final bool mask, shape (..., N): the input mask AND-ed with
            the in-range test of every coordinate.

    Raises:
        ValueError: If the canvas has no batch dimension in front of its
            n spatial dimensions and its vector dimension.
    '''
    # Default mask: every point is valid.
    if mask is None:
        mask = tf.ones(tf.shape(values)[:-1], dtype=tf.bool)
    canvas = tf.convert_to_tensor(canvas)
    indices = tf.cast(tf.convert_to_tensor(indices), dtype=tf.int32)
    values = tf.convert_to_tensor(values)
    mask = tf.cast(tf.convert_to_tensor(mask), dtype=tf.bool)
    with tf.control_dependencies(
        [
            tf.debugging.assert_rank_at_least(canvas, 3,
                message="`canvas` must be at least 3-D tensor"),
            tf.debugging.assert_rank_at_least(indices, 3,
                message="`indices` must be at least 3-D tensor"),
            tf.debugging.assert_rank_at_least(values, 3,
                message="`values` must be at least 3-D tensor"),
            tf.debugging.assert_rank_at_least(mask, 2,
                message="`mask` must be at least 2-D tensor"),
        ]
    ):
        # Static ranks; dynamic sizes come from tf.shape so unknown static
        # dimensions (e.g. inside tf.function) are handled.
        n = indices.shape[-1]                    # number of spatial dims
        batch_rank = canvas.shape.rank - n - 1   # len("...")
        if batch_rank < 1:
            raise ValueError(
                "`canvas` must have at least one batch dimension before "
                "its spatial dimensions (d1, ..., dn) and vector dimension")
        canvas_shape = tf.shape(canvas)
        batch_shape = canvas_shape[:batch_rank]
        ind_dtype = indices.dtype

        # A point is valid when its mask is True and 0 <= index < size along
        # every spatial axis.
        valid = mask
        for i in range(n):
            di = indices[..., i]
            dim_size = canvas_shape[batch_rank + i]
            valid = valid & (di >= 0) & (di < dim_size)
        mask = valid

        # Batch coordinates for tensor_scatter_nd_max: a grid of every batch
        # position, e.g. (b1, b2, b3, k) with k = batch_rank ...
        batch_ranges = [tf.range(batch_shape[i], dtype=ind_dtype)
                        for i in range(batch_rank)]
        batch_coords = tf.stack(
            tf.meshgrid(*batch_ranges, indexing='ij'), axis=-1)
        # ... repeated for every point: (b1, b2, b3, 1, k) -> (b1, b2, b3, N, k).
        num_points = tf.shape(indices)[-2]
        batch_coords = tf.broadcast_to(
            tf.expand_dims(batch_coords, axis=-2),
            tf.concat([batch_shape, [num_points, batch_rank]], axis=0))
        # Full scatter indices: batch coords followed by spatial coords,
        # (..., N, k + n).
        full_inds = tf.concat([batch_coords, indices], axis=-1)
        ind_depth = batch_rank + n

        # Invalid points are not passed to tensor_scatter_nd_max as
        # out-of-range indices. The original author reports that
        # TensorFlow < 2.4.0 raises on them. Instead, a dummy slot of size 1 is
        # prepended on the last spatial axis (dn): (..., d1, ..., dn + 1, v).
        # Invalid points are scattered into it, and it is sliced off at the
        # end. Its fill value is therefore irrelevant; it only needs the
        # canvas dtype so the concat works for any dtype.
        dummy_shape = tf.concat(
            [canvas_shape[:-2], [1], canvas_shape[-1:]], axis=0)
        dummy_channel = tf.zeros(dummy_shape, dtype=canvas.dtype)
        dummy_canvas = tf.concat([dummy_channel, canvas], axis=-2)

        # Valid points shift their dn coordinate by one to skip the dummy
        # slot. Invalid points get the all-zero index, which lands in the
        # dummy slot.
        dummy_shift = tf.constant([0] * (ind_depth - 1) + [1], dtype=ind_dtype)
        scatter_inds = tf.where(mask[..., tf.newaxis],
                                full_inds + dummy_shift,
                                tf.zeros_like(full_inds))

        new_canvas = tf.tensor_scatter_nd_max(dummy_canvas, scatter_inds, values)
        # Slice off the dummy slot (index 0 of the dn axis).
        new_canvas = new_canvas[..., 1:, :]
    return new_canvas, mask
# Dummy point cloud (batch, channel, height, width, xyz) = (1, 2, 5, 5, 3).
# Unit: meter.
values = np.array([[[[[ 0.92871926, -0.39209746,  0.12709531],
                      [ 0.37437783, -0.25560278, -2.1768249 ],
                      [-0.21010604, -0.87326627, -0.60568358],
                      [ 0.11826354,  0.72192535, -1.96805051],
                      [ 0.07642954,  0.02877341, -0.52130058]],
                     [[-1.07883079, -1.09864275, -1.48197995],
                      [-0.52746128,  0.64207189,  0.95996284],
                      [ 0.29431672, -0.79195994, -0.29312353],
                      [-0.58089971,  0.05356699, -0.18195914],
                      [ 0.63448274, -0.64338309, -0.18980063]],
                     [[-0.39415563, -2.61698209, -1.60855244],
                      [-1.85730103,  1.96747892, -1.36135689],
                      [ 0.17008098,  0.69992018, -1.69435467],
                      [-0.42376153,  0.34204736,  0.3173328 ],
                      [ 1.31884528, -1.28284411, -0.06323276]],
                     [[ 1.01415592, -1.56410225,  2.55963775],
                      [-0.1527702 , -1.27259893,  0.97006746],
                      [ 0.46391498, -0.82628582, -1.22322484],
                      [ 0.51598177, -0.90726735, -2.15268906],
                      [ 0.88671569,  0.34563078,  0.54024559]],
                     [[-1.20541569, -0.27154192, -0.05633884],
                      [-0.36523929, -1.17248391,  0.84481116],
                      [-1.03267173, -0.3065308 , -0.35678831],
                      [ 0.92520116, -0.8984506 , -0.58580828],
                      [-0.62473293, -0.74235885, -0.72037534]]],
                    [[[ 0.09297083,  0.98570852,  1.13650902],
                      [ 0.81261274,  0.21577615, -0.80296376],
                      [ 1.39902247, -0.41790638,  0.37105384],
                      [-0.235837  ,  1.14946586, -0.46826193],
                      [ 0.89406117, -0.81903676, -1.40690595]],
                     [[-1.13937087, -0.81807408,  0.0697723 ],
                      [-0.0718852 , -0.52776485, -1.79533604],
                      [ 0.56097385,  0.26405042,  0.07248514],
                      [-0.51417208,  1.28195223, -1.60939298],
                      [-0.1779261 ,  0.14759517, -0.79710853]],
                     [[ 1.07133254, -0.86649908,  1.10818405],
                      [ 0.51709258,  0.16462324, -0.10645144],
                      [-0.94297979,  0.23160525, -1.00794647],
                      [ 0.05334653,  1.0522464 , -0.6964805 ],
                      [ 0.97591096, -0.2690103 , -1.33586831]],
                     [[ 0.02043337, -1.67731703,  0.72714383],
                      [ 0.7053991 , -0.32442375, -0.41602061],
                      [ 1.01215432,  2.43477928, -0.86891597],
                      [ 1.5247537 , -1.86446265,  0.29876436],
                      [-2.26656319,  1.12710737,  2.89601227]],
                     [[ 0.59182888,  0.29882975, -0.16293282],
                      [-1.09208092, -2.08845169,  2.17915906],
                      [ 0.48356899,  0.22009589,  0.28158253],
                      [ 0.16641354,  0.36653133,  0.49896538],
                      [-0.34871221, -0.56461655,  1.49807201]]]]],
                  dtype=np.float32) + 2.
# Coordinate planes, each (1, 2, 5, 5).
x = values[..., 0]
y = values[..., 1]
z = values[..., 2]
map_res = 1.0  # map resolution (unit: meter per cell)
map_size = 3   # map cells (map_size by map_size map)
# Quantize the point cloud into map cells, (1, 2, 5, 5).
# NOTE(review): astype(np.int64) truncates toward zero, so raw bin values in
# (-1, 0) collapse into bin 0 instead of being marked out of range.
# np.floor would bin strictly, but the expected output below depends on the
# truncation. For example, channel 0, cell (0, 2) = 2.642072 comes from a
# point whose raw z bin is about -0.96.
z_bin = (-z / map_res + (map_size - 1)).astype(np.int64)
x_bin = (x / map_res + (map_size - 1) / 2).astype(np.int64)
# Reference validity mask: indices inside the map range, (1, 2, 5, 5).
isvalid = np.all(np.stack((
    z_bin >= 0, z_bin < map_size, x_bin >= 0, x_bin < map_size
), axis=0), axis=0)
# Empty map (canvas), (1, 2, map_size, map_size, 3). Cells start at -inf so
# that the scatter-max keeps the largest scattered value even when it is
# negative. Cells that receive no point stay -inf, as in the expected output
# below. (A zero-filled canvas would clamp every cell to >= 0.)
canvas = np.full((1, 2, map_size, map_size, 3), -np.inf, dtype=np.float32)
# Combine coordinates, (1, 2, 5, 5, 2).
indices = np.stack((z_bin, x_bin), axis=-1)
flat_indices = tf.reshape(indices, (1, 2, -1, 2))  # (1, 2, 25, 2)
flat_values = tf.reshape(values, (1, 2, -1, 3))    # (1, 2, 25, 3)
# Scatter.
new_canvas, mask = masked_scatter_nd_max(canvas, flat_indices, flat_values)
print(z_bin)
print(x_bin)
print(isvalid)
print(mask)
# Mark out-of-map points as (-1, -1) and list the distinct cells per channel.
indices[np.logical_not(isvalid)] = [-1, -1]
indices = indices.reshape(1, 2, -1, 2)
points = np.unique(indices[0, 0], axis=0)
print(points)
points = np.unique(indices[0, 1], axis=0)
print(points)
print(y)
print(new_canvas[..., 1])
# Expected output of the last print (new_canvas[..., 1], i.e. the maximum y
# value per map cell):
# tf.Tensor(
# [[[[     -inf  1.728458   2.642072 ]
#    [     -inf  3.9674788 -0.616982 ]
#    [     -inf      -inf      -inf  ]]
#
#   [[     -inf  1.1819259  3.149466 ]
#    [     -inf      -inf   3.2819524]
#    [     -inf      -inf      -inf  ]]]], shape=(1, 2, 3, 3), dtype=float32)
@Ending2015a
Copy link
Author

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment