Last active
January 21, 2022 09:50
-
-
Save Ending2015a/215375b470dcdd50de3c9b2252337888 to your computer and use it in GitHub Desktop.
Scatter N-D (TensorFlow 2.0 implementation). This function can be used in Active Neural SLAM to project depth maps onto a top-down height map.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
import numpy as np | |
def masked_scatter_nd_max(canvas, indices, values, mask=None):
    """Scatter vector values onto an n-D canvas, keeping per-cell maxima.

    The canvas has shape ``(..., d1, ..., dn, v)``: one or more leading
    batch dimensions ``...``, ``n`` spatial dimensions and a trailing vector
    depth ``v``. Each of the ``N`` points in ``values`` is written, within
    its own batch entry, to the cell addressed by the matching row of
    ``indices``. The result is the element-wise maximum of the existing
    canvas value and every value scattered into that cell. Points that are
    masked out, or whose index on any spatial axis falls outside
    ``[0, d_i)``, are ignored.

    For projecting a flattened point cloud onto an image-type canvas,
    ``n = 2`` (canvas ``(..., h, w, v)``, indices are (h, w) coords) and
    ``v = 3`` when the values are xyz coordinates.

    Args:
        canvas (tf.Tensor): Canvas with shape (..., d1, ..., dn, v). Its rank
            must be known statically and it must have at least one batch
            dimension.
        indices (tf.Tensor): (d1, ..., dn) coordinates where each value is
            scattered, with shape (..., N, n). Cast to int32. ``n`` must be
            known statically.
        values (tf.Tensor): Vector values with shape (..., N, v). Must have
            the same dtype as ``canvas``.
        mask (tf.Tensor, optional): Mask with shape (..., N); True marks
            points to scatter. Defaults to all True.

    Returns:
        tf.Tensor: Updated canvas, shape (..., d1, ..., dn, v).
        tf.Tensor: Final boolean mask, shape (..., N). This is the input mask
            AND-ed with the in-range test of each point's indices.

    Raises:
        ValueError: If the last dimension of ``indices`` or the rank of
            ``canvas`` is not statically known, or ``canvas`` has no batch
            dimension in front of (d1, ..., dn, v).
    """
    # default mask: every point is a candidate
    if mask is None:
        mask = tf.ones(tf.shape(values)[:-1], dtype=tf.bool)
    # convert to tensors
    canvas = tf.convert_to_tensor(canvas)
    indices = tf.cast(tf.convert_to_tensor(indices), dtype=tf.int32)
    values = tf.convert_to_tensor(values)
    mask = tf.cast(tf.convert_to_tensor(mask), dtype=tf.bool)
    with tf.control_dependencies(
        [
            tf.debugging.assert_rank_at_least(canvas, 3,
                message="`canvas` must be at least 3-D tensor"),
            tf.debugging.assert_rank_at_least(indices, 3,
                message="`indices` must be at least 3-D tensor"),
            tf.debugging.assert_rank_at_least(values, 3,
                message="`values` must be at least 3-D tensor"),
            tf.debugging.assert_rank_at_least(mask, 2,
                message="`mask` must be at least 2-D tensor")
        ]
    ):
        # The index layout needs the static number of batch dims. Deriving it
        # from static ranks avoids calling len() on a shape tensor, which only
        # works eagerly, so this also runs inside tf.function.
        n = indices.shape[-1]
        canvas_rank = canvas.shape.rank
        if n is None or canvas_rank is None:
            raise ValueError(
                "the last dimension of `indices` and the rank of `canvas` "
                "must be known statically")
        batch_rank = canvas_rank - n - 1
        if batch_rank < 1:
            raise ValueError(
                "`canvas` must have at least one batch dimension before "
                "(d1, ..., dn, v)")
        canvas_shape = tf.shape(canvas)
        batch_dims = canvas_shape[:batch_rank]
        spatial_dims = canvas_shape[batch_rank:-1]  # (d1, ..., dn)

        # Keep only points whose every coordinate lies in [0, d_i).
        # spatial_dims (n,) broadcasts against indices (..., N, n).
        in_range = tf.logical_and(indices >= 0, indices < spatial_dims)
        mask = tf.logical_and(mask, tf.reduce_all(in_range, axis=-1))

        # Batch coordinates of every point. Build a (b1, ..., bk, k) grid of
        # batch indices, then broadcast it over the N points to give
        # (b1, ..., bk, N, k).
        batch_ranges = [tf.range(dim, dtype=tf.int32)
                        for dim in tf.unstack(batch_dims, num=batch_rank)]
        batch_grid = tf.stack(
            tf.meshgrid(*batch_ranges, indexing='ij'), axis=-1)
        batch_grid = tf.broadcast_to(
            tf.expand_dims(batch_grid, axis=-2),
            tf.concat([tf.shape(indices)[:-1], [batch_rank]], axis=0))
        # full scatter indices: (b1, ..., bk, N, k + n)
        full_indices = tf.concat([batch_grid, indices], axis=-1)

        # tensor_scatter_nd_max reports an error on out-of-range indices (at
        # least on CPU), so invalid points cannot just be given bad indices.
        # Instead:
        # - prepend a dummy slot along the dn axis, shaped
        #   (..., d1, ..., d(n-1), 1, v);
        # - shift every valid dn index by one so it skips the slot;
        # - send every invalid point to the all-zero index, which lies
        #   inside the slot;
        # - slice the slot off afterwards.
        # The slot's contents are discarded, so zeros in the canvas dtype
        # are enough (and work for canvases that are not float32).
        dummy_shape = tf.concat(
            [canvas_shape[:-2], [1], canvas_shape[-1:]], axis=0)
        dummy_canvas = tf.concat(
            [tf.zeros(dummy_shape, dtype=canvas.dtype), canvas], axis=-2)
        dn_shift = np.eye(batch_rank + n, dtype=np.int32)[-1]  # (0, ..., 0, 1)
        # invalid points -> index 0 (inside the dummy slot)
        full_indices = tf.where(
            mask[..., tf.newaxis],
            full_indices + dn_shift,
            tf.zeros_like(full_indices))

        # scatter
        new_canvas = tf.tensor_scatter_nd_max(
            dummy_canvas, full_indices, values)
        # slice off the dummy slot (..., 0, v)
        new_canvas = new_canvas[..., 1:, :]
    return new_canvas, mask
# Demo: project a dummy point cloud onto a small top-down map and print the
# per-cell maxima of the y coordinate.
if __name__ == "__main__":
    # dummy point cloud (batch, channel, height, width, xyz) = (1, 2, 5, 5, 3)
    # unit: meter; every coordinate is shifted by +2
    values = np.array([[[[[ 0.92871926, -0.39209746,  0.12709531],
                          [ 0.37437783, -0.25560278, -2.1768249 ],
                          [-0.21010604, -0.87326627, -0.60568358],
                          [ 0.11826354,  0.72192535, -1.96805051],
                          [ 0.07642954,  0.02877341, -0.52130058]],
                         [[-1.07883079, -1.09864275, -1.48197995],
                          [-0.52746128,  0.64207189,  0.95996284],
                          [ 0.29431672, -0.79195994, -0.29312353],
                          [-0.58089971,  0.05356699, -0.18195914],
                          [ 0.63448274, -0.64338309, -0.18980063]],
                         [[-0.39415563, -2.61698209, -1.60855244],
                          [-1.85730103,  1.96747892, -1.36135689],
                          [ 0.17008098,  0.69992018, -1.69435467],
                          [-0.42376153,  0.34204736,  0.3173328 ],
                          [ 1.31884528, -1.28284411, -0.06323276]],
                         [[ 1.01415592, -1.56410225,  2.55963775],
                          [-0.1527702 , -1.27259893,  0.97006746],
                          [ 0.46391498, -0.82628582, -1.22322484],
                          [ 0.51598177, -0.90726735, -2.15268906],
                          [ 0.88671569,  0.34563078,  0.54024559]],
                         [[-1.20541569, -0.27154192, -0.05633884],
                          [-0.36523929, -1.17248391,  0.84481116],
                          [-1.03267173, -0.3065308 , -0.35678831],
                          [ 0.92520116, -0.8984506 , -0.58580828],
                          [-0.62473293, -0.74235885, -0.72037534]]],
                        [[[ 0.09297083,  0.98570852,  1.13650902],
                          [ 0.81261274,  0.21577615, -0.80296376],
                          [ 1.39902247, -0.41790638,  0.37105384],
                          [-0.235837  ,  1.14946586, -0.46826193],
                          [ 0.89406117, -0.81903676, -1.40690595]],
                         [[-1.13937087, -0.81807408,  0.0697723 ],
                          [-0.0718852 , -0.52776485, -1.79533604],
                          [ 0.56097385,  0.26405042,  0.07248514],
                          [-0.51417208,  1.28195223, -1.60939298],
                          [-0.1779261 ,  0.14759517, -0.79710853]],
                         [[ 1.07133254, -0.86649908,  1.10818405],
                          [ 0.51709258,  0.16462324, -0.10645144],
                          [-0.94297979,  0.23160525, -1.00794647],
                          [ 0.05334653,  1.0522464 , -0.6964805 ],
                          [ 0.97591096, -0.2690103 , -1.33586831]],
                         [[ 0.02043337, -1.67731703,  0.72714383],
                          [ 0.7053991 , -0.32442375, -0.41602061],
                          [ 1.01215432,  2.43477928, -0.86891597],
                          [ 1.5247537 , -1.86446265,  0.29876436],
                          [-2.26656319,  1.12710737,  2.89601227]],
                         [[ 0.59182888,  0.29882975, -0.16293282],
                          [-1.09208092, -2.08845169,  2.17915906],
                          [ 0.48356899,  0.22009589,  0.28158253],
                          [ 0.16641354,  0.36653133,  0.49896538],
                          [-0.34871221, -0.56461655,  1.49807201]]]]],
                      dtype=np.float32) + 2.

    # (1, 2, 5, 5)
    x = values[..., 0]
    y = values[..., 1]
    z = values[..., 2]

    map_res = 1.0  # map resolution (unit: meter per cell)
    map_size = 3  # map cells (map_size by map_size map)

    # Quantize the point cloud to map cells, (1, 2, 5, 5). Use np.floor, not
    # astype's truncation toward zero. Truncation would fold points up to one
    # cell outside the map (e.g. at -0.5 cells) into cell 0.
    z_bin = np.floor(-z / map_res + (map_size - 1)).astype(np.int64)
    x_bin = np.floor(x / map_res + (map_size - 1) / 2).astype(np.int64)

    # reference validity mask: indices inside the map range, (1, 2, 5, 5)
    isvalid = np.all(np.stack((
        z_bin >= 0, z_bin < map_size, x_bin >= 0, x_bin < map_size
    ), axis=0), axis=0)

    # Empty map (canvas), (1, 2, map_size, map_size, 3). -inf is the identity
    # for max: untouched cells stay -inf and negative values are not clipped
    # to 0, as they would be on a zero canvas.
    canvas = np.full((1, 2, map_size, map_size, 3), -np.inf, dtype=np.float32)

    # combine coordinates, (1, 2, 5, 5, 2), then flatten the points
    indices = np.stack((z_bin, x_bin), axis=-1)
    flat_indices = tf.reshape(indices, (1, 2, -1, 2))  # (1, 2, 25, 2)
    flat_values = tf.reshape(values, (1, 2, -1, 3))  # (1, 2, 25, 3)

    # scatter
    new_canvas, mask = masked_scatter_nd_max(canvas, flat_indices, flat_values)

    print(z_bin)
    print(x_bin)
    print(isvalid)
    print(mask)  # should agree with the flattened isvalid
    indices[np.logical_not(isvalid)] = [-1, -1]
    indices = indices.reshape(1, 2, -1, 2)
    points = np.unique(indices[0, 0], axis=0)
    print(points)
    points = np.unique(indices[0, 1], axis=0)
    print(points)
    print(y)
    print(new_canvas[..., 1])

    # Expected result of the last print, the per-cell max of y
    # (values as printed; exact spacing may differ):
    # tf.Tensor(
    # [[[[     -inf  1.728458  2.053567]
    #    [     -inf 3.9674788 -0.616982]
    #    [     -inf      -inf      -inf]]
    #
    #   [[     -inf      -inf  3.149466]
    #    [     -inf      -inf 3.2819524]
    #    [     -inf      -inf      -inf]]]], shape=(1, 2, 3, 3), dtype=float32)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
PyTorch implementation here: https://gist.github.com/Ending2015a/b034ebbedc55fec1d8ec3b7230a95f1e