Last active: October 19, 2023 21:13
-
-
Save sergeyprokudin/86eb10c345ab88f74d3d2fe4a6dcdddf to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# following | |
# https://github.com/facebookresearch/pytorch3d/issues/35 | |
import torch | |
import torch.nn as nn | |
class PointsRendererDepth(nn.Module):
    """Render a batch of point clouds and also return the rasterizer z-buffer.

    The renderer is built from a rasterizer and a compositor, each of which
    has a forward function. ``forward`` returns the composited images
    together with ``fragments.zbuf``, so callers get a depth buffer
    alongside the colors.

    Args:
        rasterizer: callable mapping ``(point_clouds, **kwargs)`` to a
            fragments object exposing ``idx``, ``dists`` and ``zbuf``. It
            must also expose ``raster_settings.radius`` and ``to(device)``.
        compositor: callable taking ``(idx, weights, features, **kwargs)``
            and returning channels-first images. It must expose
            ``to(device)``.
    """

    def __init__(self, rasterizer, compositor) -> None:
        super().__init__()
        self.rasterizer = rasterizer
        self.compositor = compositor

    def to(self, device):
        """Move the rasterizer and compositor to ``device`` and return ``self``."""
        # Moved manually: the cameras held by the rasterizer are not of type
        # nn.Module, so the default nn.Module.to would not reach them.
        self.rasterizer = self.rasterizer.to(device)
        self.compositor = self.compositor.to(device)
        return self

    def forward(
        self, point_clouds, **kwargs
    ) -> "tuple[torch.Tensor, torch.Tensor]":
        """Rasterize and composite ``point_clouds``.

        Args:
            point_clouds: point-cloud batch passed to the rasterizer; must
                provide ``features_packed()`` returning a (P, C) tensor.
            **kwargs: forwarded to both the rasterizer and the compositor.

        Returns:
            ``(images, zbuf)``:
            - ``images`` is the compositor output with its channel axis
              (axis 1) moved to the end.
            - ``zbuf`` is ``fragments.zbuf`` exactly as produced by the
              rasterizer (presumably (N, H, W, K) per-point depths; TODO
              confirm against the rasterizer in use).
        """
        fragments = self.rasterizer(point_clouds, **kwargs)

        # Weight each candidate point by its distance to the pixel. This
        # could be done differently, e.g. predicted instead of being a fixed
        # function of the distance. ``dists`` is treated as squared
        # distances, hence the division by r * r.
        r = self.rasterizer.raster_settings.radius
        # Move the last axis (presumably the K points per pixel) to axis 1,
        # the channels-first layout the compositor consumes.
        dists2 = fragments.dists.permute(0, 3, 1, 2)
        weights = 1 - dists2 / (r * r)

        images = self.compositor(
            fragments.idx.long().permute(0, 3, 1, 2),
            weights,
            # Transpose the packed features from (P, C) to (C, P).
            point_clouds.features_packed().permute(1, 0),
            **kwargs,
        )

        # Permute so that the channel axis comes last in the image.
        images = images.permute(0, 2, 3, 1)
        return images, fragments.zbuf
def rgb_depth_to_point_cloud_ortho(img_depth, img_rgb, R, T, device='cuda'):
    """Lift an RGB-D image to a point cloud using an orthographic mapping.

    The pixel at row ``y`` and column ``x`` becomes the point
    ``(1 - 2 * x / W, 1 - 2 * y / H, img_depth[y, x])``. The first pixel
    maps to (1, 1), and coordinates decrease to the right and downwards.
    Pixels whose RGB channels sum to zero are dropped.

    Args:
        img_depth: 2-D depth tensor of shape (H, W).
        img_rgb: color tensor with H * W * 3 elements, presumably (H, W, 3).
            It is reshaped to (H * W, 3) in the same row-major pixel order
            as ``img_depth``.
        R, T: currently unused. No camera transform is applied; they are
            kept for interface compatibility.
        device: device on which the outputs are created.

    Returns:
        ``(xyz, rgb)``: a (P, 3) tensor of point coordinates and the
        matching (P, 3) colors, where P is the number of kept pixels.
    """
    # Tensor shape is (rows, cols) = (H, W). Unpacking it as (W, H) would
    # make the coordinate grid transposed and break stacking for
    # non-square images.
    img_height, img_width = img_depth.shape
    dtype = torch.get_default_dtype()

    xs = torch.arange(img_width, device=device, dtype=dtype)
    ys = torch.arange(img_height, device=device, dtype=dtype)
    # Broadcast the per-column / per-row coordinates to the full (H, W) grid.
    u = (-2 * xs / img_width + 1).expand(img_height, img_width)
    v = (-2 * ys / img_height + 1).unsqueeze(1).expand(img_height, img_width)

    uvz = torch.stack([u, v, img_depth.to(device)], 2).reshape([-1, 3])
    rgb = img_rgb.reshape([-1, 3]).to(device)

    # All-zero (black) pixels are treated as background and discarded.
    mask = torch.sum(rgb, 1) != 0
    return uvz[mask], rgb[mask]
def rgb_depth_to_point_cloud(img_depth, img_rgb, cameras, device='cuda'):
    """Unproject an RGB-D image to a point cloud with ``cameras``.

    Each pixel gets screen coordinates on an evenly spaced grid over
    [-1, 1] in both directions. The coordinates are negated so that the
    first pixel maps to (+1, +1). Together with its depth this forms the
    (x, y, depth) triple passed to ``cameras.unproject_points``. Pixels
    whose RGB channels sum to zero are dropped.

    NOTE(review): both axes span [-1, 1] regardless of aspect ratio. That
    presumably matches the camera's NDC convention only for square images;
    confirm before using non-square inputs.

    Args:
        img_depth: 2-D depth tensor of shape (H, W).
        img_rgb: color tensor with H * W * 3 elements, presumably (H, W, 3).
        cameras: camera object providing ``to(device)`` and
            ``unproject_points``.
        device: device on which the outputs are created.

    Returns:
        ``(xyz, rgb)``:
        - ``xyz`` is what ``cameras.unproject_points`` returns for a
          batched (1, P, 3) input, presumably (1, P, 3).
        - ``rgb`` is the matching (P, 3) colors.
    """
    # Grid size is taken from the depth map itself.
    img_height, img_width = img_depth.shape
    dtype = torch.get_default_dtype()

    x = torch.linspace(-1, 1, img_width, device=device, dtype=dtype)
    y = torch.linspace(-1, 1, img_height, device=device, dtype=dtype)
    # Broadcast to the full (H, W) grid (same layout as np.meshgrid(x, y)).
    xs = x.expand(img_height, img_width)
    ys = y.unsqueeze(1).expand(img_height, img_width)

    uvz = torch.stack([-xs, -ys, img_depth.to(device)], 2).reshape([-1, 3])
    rgb = img_rgb.reshape([-1, 3]).to(device)

    # All-zero (black) pixels are treated as background and discarded.
    mask = torch.sum(rgb, 1) != 0
    rgb = rgb[mask]
    uvz = uvz[mask]

    # unproject_points is given a leading batch dimension of size 1.
    xyz = cameras.to(device).unproject_points(uvz.unsqueeze(0))
    return xyz, rgb
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.