Skip to content

Instantly share code, notes, and snippets.

@Erol444
Created March 12, 2022 17:14
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Erol444/0a9f4ae505ef9208edb144e0237f1050 to your computer and use it in GitHub Desktop.
Save Erol444/0a9f4ae505ef9208edb144e0237f1050 to your computer and use it in GitHub Desktop.
Kornia depth_to_3d function
def depth_to_3d(
    depth: torch.Tensor, camera_matrix: torch.Tensor, normalize_points: bool = False
) -> torch.Tensor:
    """Compute a 3d point per pixel given its depth value and the camera intrinsics.

    Every pixel ``(u, v)`` is back-projected through the pinhole model
    ``x = (u - cx) / fx``, ``y = (v - cy) / fy``, ``z = 1`` and the resulting ray
    is scaled by the pixel's depth value.

    Args:
        depth: image tensor containing a depth value per pixel with shape :math:`(B, 1, H, W)`.
        camera_matrix: tensor containing the camera intrinsics with shape :math:`(B, 3, 3)`.
        normalize_points: whether to normalise the pointcloud. This must be set to `True` when the depth is
            represented as the Euclidean ray length from the camera position.

    Return:
        tensor with a 3d point per pixel of the same resolution as the input :math:`(B, 3, H, W)`.

    Raises:
        TypeError: if ``depth`` or ``camera_matrix`` is not a ``torch.Tensor``.
        ValueError: if ``depth`` is not shaped :math:`(B, 1, H, W)`, if ``camera_matrix`` is not
            shaped :math:`(B, 3, 3)`, or if the two tensors live on different devices.

    Example:
        >>> depth = torch.rand(1, 1, 4, 4)
        >>> K = torch.eye(3)[None]
        >>> depth_to_3d(depth, K).shape
        torch.Size([1, 3, 4, 4])
    """
    if not isinstance(depth, torch.Tensor):
        raise TypeError(f"Input depht type is not a torch.Tensor. Got {type(depth)}.")

    if not (len(depth.shape) == 4 and depth.shape[-3] == 1):
        raise ValueError(f"Input depth musth have a shape (B, 1, H, W). Got: {depth.shape}")

    if not isinstance(camera_matrix, torch.Tensor):
        raise TypeError(f"Input camera_matrix type is not a torch.Tensor. " f"Got {type(camera_matrix)}.")

    if not (len(camera_matrix.shape) == 3 and camera_matrix.shape[-2:] == (3, 3)):
        raise ValueError(f"Input camera_matrix must have a shape (B, 3, 3). " f"Got: {camera_matrix.shape}.")

    if depth.device != camera_matrix.device:
        raise ValueError("Input tensors must be all in the same device.")

    _, _, height, width = depth.shape

    # Pixel coordinates as broadcastable 1-D vectors, created directly on the
    # depth device. The indices are generated in float32 and then cast so that
    # low-precision depth dtypes see the same values as a float32 grid would.
    u_coord: torch.Tensor = (
        torch.arange(width, dtype=torch.float32, device=depth.device).to(depth.dtype).view(1, 1, width)
    )  # 1x1xW
    v_coord: torch.Tensor = (
        torch.arange(height, dtype=torch.float32, device=depth.device).to(depth.dtype).view(1, height, 1)
    )  # 1xHx1

    # unpack intrinsics, one scalar per batch element
    fx: torch.Tensor = camera_matrix[:, 0, 0].view(-1, 1, 1)  # Bx1x1
    fy: torch.Tensor = camera_matrix[:, 1, 1].view(-1, 1, 1)
    cx: torch.Tensor = camera_matrix[:, 0, 2].view(-1, 1, 1)
    cy: torch.Tensor = camera_matrix[:, 1, 2].view(-1, 1, 1)

    # projection eq. K_inv * [u v 1]'
    #   x = (u - cx) * Z / fx
    #   y = (v - cy) * Z / fy
    x_coord: torch.Tensor = (u_coord - cx) / fx  # Bx1xW
    y_coord: torch.Tensor = (v_coord - cy) / fy  # BxHx1
    x_coord, y_coord = torch.broadcast_tensors(x_coord, y_coord)  # BxHxW each

    # Rays on the z = 1 plane, assembled channel-first so no permute is needed.
    rays: torch.Tensor = torch.stack([x_coord, y_coord, torch.ones_like(x_coord)], dim=1)  # Bx3xHxW

    if normalize_points:
        # Depth is a distance along the ray, so the ray must have unit length.
        rays = torch.nn.functional.normalize(rays, dim=1, p=2.0)

    return rays * depth  # Bx3xHxW
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment