# Render RGB images and depth maps from point clouds with PyTorch3D,
# following https://github.com/facebookresearch/pytorch3d/issues/35

import numpy as np
import torch
import torch.nn as nn

class PointsRendererDepth(nn.Module):
    """
    A class for rendering a batch of points. The class should be
    initialized with a rasterizer and a compositor class, each of
    which has a forward function.
    """

    def __init__(self, rasterizer, compositor) -> None:
        super().__init__()
        self.rasterizer = rasterizer
        self.compositor = compositor

    def to(self, device):
        # Move the rasterizer to the device manually, as the cameras
        # stored inside it are not of type nn.Module.
        self.rasterizer = self.rasterizer.to(device)
        self.compositor = self.compositor.to(device)
        return self
    def forward(self, point_clouds, **kwargs):
        fragments = self.rasterizer(point_clouds, **kwargs)

        # Construct weights based on the distance of a point to the true point.
        # However, this could be done differently: e.g. predicted as opposed
        # to a function of the distances.
        r = self.rasterizer.raster_settings.radius
        dists2 = fragments.dists.permute(0, 3, 1, 2)
        weights = 1 - dists2 / (r * r)
        images = self.compositor(
            fragments.idx.long().permute(0, 3, 1, 2),
            weights,
            point_clouds.features_packed().permute(1, 0),
            **kwargs,
        )

        # Permute so that the channel dimension comes last.
        images = images.permute(0, 2, 3, 1)

        # In addition to the composited images, return the z-buffer,
        # which holds the per-pixel depth of the rasterized points.
        return images, fragments.zbuf
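
# Usage sketch (an assumption, not part of the original gist): the renderer
# is assembled exactly like PyTorch3D's stock PointsRenderer, the only
# difference being the extra zbuf return value, e.g.
#
#   renderer = PointsRendererDepth(
#       rasterizer=PointsRasterizer(cameras=cameras, raster_settings=raster_settings),
#       compositor=AlphaCompositor(),
#   )
#   images, zbuf = renderer(point_clouds)
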

def rgb_depth_to_point_cloud_ortho(img_depth, img_rgb, R, T, device='cuda'):
    """Unproject an orthographic depth + RGB rendering back into a point cloud.

    Note: the R and T arguments are not used by this function.
    """
    img_height, img_width = img_depth.shape
    x = np.arange(0, img_width)
    y = np.arange(0, img_height)
    xs, ys = np.meshgrid(x, y)
    xs = torch.Tensor(xs).to(device)
    ys = torch.Tensor(ys).to(device)

    # Map pixel coordinates to NDC in [-1, 1] and stack them with the depth.
    uvz = torch.stack([-2 * xs / img_width + 1,
                       -2 * ys / img_height + 1,
                       img_depth.to(device)], 2).reshape([-1, 3])
    rgb = img_rgb.reshape([-1, 3]).to(device)

    # Drop background pixels (all-zero color).
    mask = torch.sum(rgb, 1) != 0
    rgb = rgb[mask]
    uvz = uvz[mask]

    # For an orthographic camera, the masked (u, v, depth) triples are
    # used directly as the point coordinates.
    xyz = uvz
    return xyz, rgb
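
# Usage sketch (an assumption): given a depth map and RGB image rendered
# with an orthographic camera, the masked (u, v, depth) triples already
# form the point cloud, e.g.
#
#   xyz, rgb = rgb_depth_to_point_cloud_ortho(depth, image, R, T, device=device)
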

def rgb_depth_to_point_cloud(img_depth, img_rgb, cameras, device='cuda'):
    """Unproject a depth + RGB rendering back into a world-space point cloud
    using the camera's unproject_points method.
    """
    img_height, img_width = img_depth.shape
    x = np.linspace(-1, 1, img_width)
    y = np.linspace(-1, 1, img_height)
    xs, ys = np.meshgrid(x, y)
    xs = torch.Tensor(xs).to(device)
    ys = torch.Tensor(ys).to(device)

    # NDC coordinates (negated to match the PyTorch3D screen conventions)
    # stacked with the per-pixel depth.
    uvz = torch.stack([-xs, -ys, img_depth.to(device)], 2).reshape([-1, 3])
    rgb = img_rgb.reshape([-1, 3]).to(device)

    # Drop background pixels (all-zero color).
    mask = torch.sum(rgb, 1) != 0
    rgb = rgb[mask]
    uvz = uvz[mask]
    xyz = cameras.to(device).unproject_points(uvz.unsqueeze(0))
    return xyz, rgb
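
if __name__ == "__main__":
    # End-to-end sketch (an assumption, not part of the original gist):
    # render a random point cloud to RGB + depth, then unproject the
    # rendering back into a point cloud with the same camera. The camera,
    # image size, radius, and the toy point cloud are illustrative choices.
    from pytorch3d.renderer import (
        AlphaCompositor,
        FoVOrthographicCameras,
        PointsRasterizationSettings,
        PointsRasterizer,
    )
    from pytorch3d.structures import Pointclouds

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    cameras = FoVOrthographicCameras(device=device)
    raster_settings = PointsRasterizationSettings(
        image_size=256, radius=0.01, points_per_pixel=10
    )
    renderer = PointsRendererDepth(
        rasterizer=PointsRasterizer(cameras=cameras, raster_settings=raster_settings),
        compositor=AlphaCompositor(),
    ).to(device)

    verts = torch.rand(1000, 3, device=device) - 0.5  # toy point cloud
    feats = torch.rand(1000, 3, device=device)        # per-point RGB features
    images, zbuf = renderer(Pointclouds(points=[verts], features=[feats]))

    # Nearest-point depth per pixel; zbuf is -1 where no point was rasterized.
    depth = zbuf[0, ..., 0]
    xyz, rgb = rgb_depth_to_point_cloud(depth, images[0], cameras, device=device)
    print(xyz.shape, rgb.shape)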