Skip to content

Instantly share code, notes, and snippets.

@Ending2015a
Created January 21, 2022 09:49
Show Gist options
  • Save Ending2015a/b034ebbedc55fec1d8ec3b7230a95f1e to your computer and use it in GitHub Desktop.
Save Ending2015a/b034ebbedc55fec1d8ec3b7230a95f1e to your computer and use it in GitHub Desktop.
Scatter N-D, PyTorch implementation. This function can be used in Active Neural SLAM to project depth maps onto a top-down height map.
import numpy as np
import torch
from torch import nn
import torch_scatter # pip install torch-scatter
def ravel_index(index, shape):
    """Ravel multi-dimensional indices into flat (row-major) 1D indices.

    Gives the same result as ``np.ravel_multi_index`` for in-range indices.
    No bounds checking is performed: out-of-range coordinates simply produce
    out-of-range (possibly negative) flat indices.

    Args:
        index (array-like or torch.Tensor): integer coordinates with shape
            (..., n). The last axis holds one coordinate per dimension, in
            the same order as `shape`.
        shape (sequence of int): sizes of the n indexed dimensions.

    Returns:
        torch.Tensor: int64 flat indices with shape (...,).
    """
    # as_tensor avoids the copy-construct warning that torch.tensor() emits
    # when handed an existing tensor; the input is never modified in place.
    index = torch.as_tensor(index, dtype=torch.int64)
    # tuple(shape) lets lists and torch.Size be passed as well as tuples.
    # The helper tensor lives on the index's device so the product below
    # does not mix devices.
    reversed_dims = torch.tensor(
        (1,) + tuple(shape)[::-1], dtype=torch.int64, device=index.device
    )
    # Cumulative products of the reversed sizes, flipped back, are the
    # row-major strides: (d2*...*dn, ..., dn, 1).
    strides = torch.cumprod(reversed_dims, dim=0)[:-1].flip(0)
    return (index * strides).sum(dim=-1)
def masked_scatter_nd_max(canvas, indices, values, mask=None, fill_value=-np.inf):
    """Scatter vector values onto an n-dim canvas, keeping the per-cell maximum.

    For projecting onto an image-type canvas, n = 2 and the canvas has shape
    (..., h, w, v), where `v` is the length of each value vector, e.g. v = 3
    when projecting a point cloud (xyz coords) to a height map. In that case
    `values` is the flattened batch of point clouds with `N` points each and
    `indices` holds the (h, w) cell coordinate of every point.

    The contents of `canvas` are NOT preserved: the argument only provides
    the output shape, dtype and device. Every cell starts from `fill_value`
    and is then raised to the element-wise maximum of the value vectors that
    land on it.

    Args:
        canvas (array-like or torch.Tensor): canvas with shape
            (..., d1, ..., dn, v).
        indices (array-like or torch.Tensor): integer (d1, ..., dn)
            coordinates of each value, with shape (..., N, n). Points whose
            coordinates fall outside the canvas are ignored.
        values (array-like or torch.Tensor): vector values with shape
            (..., N, v).
        mask (array-like or torch.Tensor, optional): boolean mask with shape
            (..., N), True where the point should be used. Defaults to all
            True.
        fill_value (float): initial value of every cell; cells that receive
            no point keep it.

    Returns:
        torch.Tensor: the new canvas, shape (..., d1, ..., dn, v).
        torch.Tensor: boolean tensor with the same shape as the canvas, True
            where the result is +/-inf. With the default `fill_value` and
            finite `values` this marks the cells that received no point.
    """
    # as_tensor avoids the copy-construct warning torch.tensor() emits for
    # inputs that are already tensors; none of the inputs is modified.
    canvas = torch.as_tensor(canvas)
    indices = torch.as_tensor(indices).to(dtype=torch.int64)
    values = torch.as_tensor(values)
    if mask is None:
        mask = torch.ones(values.shape[:-1], dtype=torch.bool, device=values.device)
    else:
        mask = torch.as_tensor(mask).to(dtype=torch.bool)

    # get dimensions
    n = indices.shape[-1]
    v = canvas.shape[-1]
    d1_dn = tuple(canvas.shape[-n-1:-1])
    batch_dims = tuple(canvas.shape[:-n-1])

    # a point is valid if it is unmasked and every coordinate is in range
    valid = mask
    for axis, size in enumerate(d1_dn):
        coord = indices[..., axis]
        valid = valid & (coord >= 0) & (coord < size)

    # Flat slot 0 is a dummy cell that absorbs every invalid point; real
    # cells are shifted up by one.
    flat_indices = (ravel_index(indices, d1_dn) + 1).masked_fill(~valid, 0)  # (..., N)

    # flattened output buffer, (..., 1 + d1*...*dn, v), pre-filled
    num_cells = int(np.prod(d1_dn, dtype=np.int64))
    flat_canvas = canvas.new_full((*batch_dims, 1 + num_cells, v), fill_value)

    # in-place max-reduction of the values into the buffer along the cell axis
    torch_scatter.scatter_max(values, flat_indices, dim=-2, out=flat_canvas)

    # drop the dummy cell and restore the n-dim layout
    canvas = flat_canvas[..., 1:, :].reshape(canvas.shape)
    mask = torch.isinf(canvas)
    return canvas, mask
# Dummy point cloud used by the demo below, laid out as
# (batch, channel, height, width, xyz) = (1, 2, 5, 5, 3).
# unit: meter
# Every coordinate is shifted by +2 (see the `+ 2.` after the literal).
values = np.array([[[[[ 0.92871926, -0.39209746, 0.12709531],
[ 0.37437783, -0.25560278, -2.1768249 ],
[-0.21010604, -0.87326627, -0.60568358],
[ 0.11826354, 0.72192535, -1.96805051],
[ 0.07642954, 0.02877341, -0.52130058]],
[[-1.07883079, -1.09864275, -1.48197995],
[-0.52746128, 0.64207189, 0.95996284],
[ 0.29431672, -0.79195994, -0.29312353],
[-0.58089971, 0.05356699, -0.18195914],
[ 0.63448274, -0.64338309, -0.18980063]],
[[-0.39415563, -2.61698209, -1.60855244],
[-1.85730103, 1.96747892, -1.36135689],
[ 0.17008098, 0.69992018, -1.69435467],
[-0.42376153, 0.34204736, 0.3173328 ],
[ 1.31884528, -1.28284411, -0.06323276]],
[[ 1.01415592, -1.56410225, 2.55963775],
[-0.1527702 , -1.27259893, 0.97006746],
[ 0.46391498, -0.82628582, -1.22322484],
[ 0.51598177, -0.90726735, -2.15268906],
[ 0.88671569, 0.34563078, 0.54024559]],
[[-1.20541569, -0.27154192, -0.05633884],
[-0.36523929, -1.17248391, 0.84481116],
[-1.03267173, -0.3065308 , -0.35678831],
[ 0.92520116, -0.8984506 , -0.58580828],
[-0.62473293, -0.74235885, -0.72037534]]],
# second channel (index 1 along axis 1)
[[[ 0.09297083, 0.98570852, 1.13650902],
[ 0.81261274, 0.21577615, -0.80296376],
[ 1.39902247, -0.41790638, 0.37105384],
[-0.235837 , 1.14946586, -0.46826193],
[ 0.89406117, -0.81903676, -1.40690595]],
[[-1.13937087, -0.81807408, 0.0697723 ],
[-0.0718852 , -0.52776485, -1.79533604],
[ 0.56097385, 0.26405042, 0.07248514],
[-0.51417208, 1.28195223, -1.60939298],
[-0.1779261 , 0.14759517, -0.79710853]],
[[ 1.07133254, -0.86649908, 1.10818405],
[ 0.51709258, 0.16462324, -0.10645144],
[-0.94297979, 0.23160525, -1.00794647],
[ 0.05334653, 1.0522464 , -0.6964805 ],
[ 0.97591096, -0.2690103 , -1.33586831]],
[[ 0.02043337, -1.67731703, 0.72714383],
[ 0.7053991 , -0.32442375, -0.41602061],
[ 1.01215432, 2.43477928, -0.86891597],
[ 1.5247537 , -1.86446265, 0.29876436],
[-2.26656319, 1.12710737, 2.89601227]],
[[ 0.59182888, 0.29882975, -0.16293282],
[-1.09208092, -2.08845169, 2.17915906],
[ 0.48356899, 0.22009589, 0.28158253],
[ 0.16641354, 0.36653133, 0.49896538],
[-0.34871221, -0.56461655, 1.49807201]]]]], dtype=np.float32) + 2.
# Demo: project the dummy point cloud onto a top-down map and keep, per map
# cell, the element-wise maximum of the points that fall into it.

# per-axis coordinate maps, each (1, 2, 5, 5)
x = values[..., 0]
y = values[..., 1]
z = values[..., 2]

map_res = 1.0  # map resolution (unit: meter per cell)
map_size = 3  # map cells (map_size by map_size map)

# quantize point cloud to cell coordinates, each (1, 2, 5, 5)
# NOTE(review): astype(np.int64) truncates toward zero rather than flooring,
# so fractional values in (-1, 0) end up in bin 0. The expected results at
# the bottom rely on this behavior -- confirm it is intended.
z_bin = (-z/map_res + (map_size-1)).astype(np.int64)
x_bin = (x/map_res + (map_size-1)/2).astype(np.int64)

# Reference validity mask (indices inside the map range), (1, 2, 5, 5).
# It is only used for the printouts below; masked_scatter_nd_max performs
# its own range check.
isvalid = (z_bin >= 0) & (z_bin < map_size) & (x_bin >= 0) & (x_bin < map_size)

# create empty map (canvas): batch dims + (map_size, map_size) cells + xyz,
# i.e. (1, 2, 3, 3, 3) for the dummy data
canvas = np.zeros(
    values.shape[:2] + (map_size, map_size, values.shape[-1]), dtype=np.float32
)

# combine coordinates into (row, col) = (z_bin, x_bin) pairs
indices = np.stack((z_bin, x_bin), axis=-1)  # (1, 2, 5, 5, 2)

# flatten the spatial dims so each batch entry is a list of N points
flat_indices = torch.tensor(indices).view(*indices.shape[:2], -1, indices.shape[-1])  # (1, 2, 25, 2)
flat_values = torch.tensor(values).view(*values.shape[:2], -1, values.shape[-1])  # (1, 2, 25, 3)

# scatter
new_canvas, mask = masked_scatter_nd_max(canvas, flat_indices, flat_values)

print('z_bin:', z_bin)
print('x_bin:', x_bin)
print('isvalid:', isvalid)
print('mask:', mask)

# list the distinct cells hit in each channel; invalid points collapse to (-1, -1)
indices[np.logical_not(isvalid)] = [-1, -1]
indices = indices.reshape(1, 2, -1, 2)
for channel in range(indices.shape[1]):
    points = np.unique(indices[0, channel], axis=0)
    print('points:', points)

print('y:', y)
print('new_canvas:', new_canvas[..., 1])

# Expecting results for new_canvas[..., 1]:
# tensor([[[[ -inf, 1.7285, 2.6421],
#           [ -inf, 3.9675, -0.6170],
#           [ -inf, -inf, -inf]],
#          [[ -inf, 1.1819, 3.1495],
#           [ -inf, -inf, 3.2820],
#           [ -inf, -inf, -inf]]]])
@Ending2015a
Copy link
Author

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment