-
-
Save Erol444/0a9f4ae505ef9208edb144e0237f1050 to your computer and use it in GitHub Desktop.
Kornia depth_to_3d function
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def depth_to_3d(
    depth: torch.Tensor, camera_matrix: torch.Tensor, normalize_points: bool = False
) -> torch.Tensor:
    """Compute a 3d point per pixel given its depth value and the camera intrinsics.

    Each pixel ``(u, v)`` is back-projected through the pinhole model
    ``x = (u - cx) / fx``, ``y = (v - cy) / fy``, ``z = 1`` (skew is ignored),
    and the resulting ray is scaled by the pixel's depth value.

    Args:
        depth: image tensor containing a depth value per pixel with shape :math:`(B, 1, H, W)`.
        camera_matrix: tensor containing the camera intrinsics with shape :math:`(B, 3, 3)`.
            A batch of 1 broadcasts against any depth batch size (and vice versa).
        normalize_points: whether to normalise the pointcloud. This must be set to ``True`` when the depth is
            represented as the Euclidean ray length from the camera position. With the default ``False``
            the depth is treated as the Z coordinate of each point.

    Return:
        tensor with a 3d point per pixel of the same resolution as the input :math:`(B, 3, H, W)`.

    Raises:
        TypeError: if ``depth`` or ``camera_matrix`` is not a ``torch.Tensor``.
        ValueError: if an input has the wrong shape or the inputs are on different devices.

    Example:
        >>> depth = torch.rand(1, 1, 4, 4)
        >>> K = torch.eye(3)[None]
        >>> depth_to_3d(depth, K).shape
        torch.Size([1, 3, 4, 4])
    """
    if not isinstance(depth, torch.Tensor):
        raise TypeError(f"Input depth type is not a torch.Tensor. Got {type(depth)}.")
    if not (len(depth.shape) == 4 and depth.shape[-3] == 1):
        raise ValueError(f"Input depth must have a shape (B, 1, H, W). Got: {depth.shape}")
    if not isinstance(camera_matrix, torch.Tensor):
        raise TypeError(f"Input camera_matrix type is not a torch.Tensor. Got {type(camera_matrix)}.")
    if not (len(camera_matrix.shape) == 3 and camera_matrix.shape[-2:] == (3, 3)):
        raise ValueError(f"Input camera_matrix must have a shape (B, 3, 3). Got: {camera_matrix.shape}.")
    if depth.device != camera_matrix.device:
        raise ValueError("Input tensors must be all in the same device.")

    _, _, height, width = depth.shape

    # Integer pixel coordinates: built in float32 on depth's device, then cast to depth's dtype.
    us = torch.arange(width, dtype=torch.float32, device=depth.device).to(depth.dtype)
    vs = torch.arange(height, dtype=torch.float32, device=depth.device).to(depth.dtype)
    # HxW grids via broadcasting: u_coord[v, u] == u and v_coord[v, u] == v.
    # This replaces torch.meshgrid, which warns when called without `indexing=`.
    u_coord = us[None, :].expand(height, width)
    v_coord = vs[:, None].expand(height, width)

    # Reshape intrinsics to Bx1x1 so they broadcast against the HxW pixel grids.
    intrinsics = camera_matrix[:, None, None]  # Bx1x1x3x3
    fx = intrinsics[..., 0, 0]
    fy = intrinsics[..., 1, 1]
    cx = intrinsics[..., 0, 2]
    cy = intrinsics[..., 1, 2]

    # Inverse projection K^-1 [u v 1]^T:
    #   x = (u - cx) / fx,  y = (v - cy) / fy,  z = 1
    x_coord = (u_coord - cx) / fx  # BxHxW
    y_coord = (v_coord - cy) / fy  # BxHxW
    rays = torch.stack([x_coord, y_coord, torch.ones_like(x_coord)], dim=1)  # Bx3xHxW

    if normalize_points:
        # Unit-length rays, so scaling by depth places each point at that Euclidean distance.
        rays = torch.nn.functional.normalize(rays, p=2.0, dim=1)

    # depth is Bx1xHxW and broadcasts across the 3 coordinate channels.
    return rays * depth  # Bx3xHxW
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.