@wangg12
Last active January 2, 2023 21:34
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import math
import numpy as np
from typing import Tuple
import torch
import torch.nn.functional as F
from pytorch3d.transforms import Rotate, Transform3d, Translate
from .utils import TensorProperties, convert_to_tensors_and_broadcast
# Default values for rotation and translation matrices.
r = np.expand_dims(np.eye(3), axis=0) # (1, 3, 3)
t = np.expand_dims(np.zeros(3), axis=0) # (1, 3)
class OpenGLRealPerspectiveCameras(TensorProperties):
"""
A class which stores a batch of parameters to generate a batch of
projection matrices using the OpenGL convention for a perspective camera.
The extrinsics of the camera (R and T matrices) can also be set in the
initializer or passed in to `get_full_projection_transform` to get
the full transformation from world -> screen.
The `transform_points` method calculates the full world -> screen transform
and then applies it to the input points.
The transforms can also be returned separately as Transform3d objects.
"""
def __init__(
self,
focal_length=1.0,
principal_point=((0.0, 0.0),),
R=r,
T=t,
znear=0.01,
zfar=100.0,
x0=0,
y0=0,
w=640,
h=480,
device="cpu",
):
"""
__init__(self, znear, zfar, R, T, device) -> None # noqa
Args:
znear: near clipping plane of the view frustum.
zfar: far clipping plane of the view frustum.
R: Rotation matrix of shape (N, 3, 3)
T: Translation matrix of shape (N, 3)
focal_length: focal length(s) fx (and optionally fy) in pixels.
principal_point: principal point (px, py) in pixels.
x0, y0: pixel offset of the rendered viewport (e.g. a crop) inside the full image.
w, h: width and height of the rendered image in pixels.
device: torch.device or string
"""
# The initializer formats all inputs to torch tensors and broadcasts
# all the inputs to have the same batch dimension where necessary.
super().__init__(
device=device,
focal_length=focal_length,
principal_point=principal_point,
R=R,
T=T,
znear=znear,
zfar=zfar,
x0=x0,
y0=y0,
h=h,
w=w,
)
def get_projection_transform(self, **kwargs) -> Transform3d:
"""
Calculate the OpenGL perspective projection matrix with a symmetric
viewing frustum. Use column major order.
Args:
**kwargs: parameters for the projection can be passed in as keyword
arguments to override the default values set in `__init__`.
Returns:
P: a Transform3d object which represents a batch of projection
matrices of shape (N, 4, 4)
.. code-block:: python
q = -(far + near)/(far - near)
qn = -2*far*near/(far-near)
P.T = [
[2*fx/w, 0, 0, 0],
[0, -2*fy/h, 0, 0],
[(2*px-w)/w, (-2*py+h)/h, -q, 1],
[0, 0, qn, 0],
]
Depending on the convention, the second and third rows are sometimes negated
(P[1, :] *= -1, P[2, :] *= -1).
"""
znear = kwargs.get("znear", self.znear) # pyre-ignore[16]
zfar = kwargs.get("zfar", self.zfar) # pyre-ignore[16]
x0 = kwargs.get("x0", self.x0) # pyre-ignore[16]
y0 = kwargs.get("y0", self.y0) # pyre-ignore[16]
w = kwargs.get("w", self.w) # pyre-ignore[16]
h = kwargs.get("h", self.h) # pyre-ignore[16]
principal_point = kwargs.get(
"principal_point", self.principal_point
) # pyre-ignore[16]
focal_length = kwargs.get(
"focal_length", self.focal_length
) # pyre-ignore[16]
if not torch.is_tensor(focal_length):
focal_length = torch.tensor(focal_length, device=self.device)
if len(focal_length.shape) in (0, 1) or focal_length.shape[1] == 1:
fx = fy = focal_length
else:
fx, fy = focal_length.unbind(1)
if not torch.is_tensor(principal_point):
principal_point = torch.tensor(principal_point, device=self.device)
px, py = principal_point.unbind(1)
P = torch.zeros(
(self._N, 4, 4), device=self.device, dtype=torch.float32
)
ones = torch.ones((self._N), dtype=torch.float32, device=self.device)
# NOTE: In OpenGL the projection matrix changes the handedness of the
# coordinate frame, i.e. the NDC space positive z direction is the
# camera space negative z direction. This is because the sign of the z
# in the projection matrix is set to -1.0.
# In pytorch3d we maintain a right handed coordinate system throughout,
# so the z sign is 1.0.
z_sign = 1.0
# define P.T directly
P[:, 0, 0] = 2.0 * fx / w
P[:, 1, 1] = -2.0 * fy / h
P[:, 2, 0] = -(-2 * px + w + 2 * x0) / w
P[:, 2, 1] = -(+2 * py - h + 2 * y0) / h
P[:, 2, 3] = z_sign * ones
# NOTE: This part of the matrix is for z renormalization in OpenGL
# which maps the z to [-1, 1]. This won't work yet as the torch3d
# rasterizer ignores faces which have z < 0.
# P[:, 2, 2] = z_sign * (far + near) / (far - near)
# P[:, 2, 3] = -2.0 * far * near / (far - near)
# P[:, 2, 3] = z_sign * torch.ones((N))
# NOTE: This maps the z coordinate from [0, 1] where z = 0 if the point
# is at the near clipping plane and z = 1 when the point is at the far
# clipping plane. This replaces the OpenGL z normalization to [-1, 1]
# until rasterization is changed to clip at z = -1.
P[:, 2, 2] = z_sign * zfar / (zfar - znear)
P[:, 3, 2] = -(zfar * znear) / (zfar - znear)
# OpenGL uses column vectors whereas pytorch3d uses row vectors, so the
# matrix above is already written in transposed (row-vector) form and no
# further transpose is needed.
transform = Transform3d(device=self.device)
transform._matrix = P
return transform
def clone(self):
other = OpenGLRealPerspectiveCameras(device=self.device)
return super().clone(other)
def get_camera_center(self, **kwargs):
"""
Return the 3D location of the camera optical center
in the world coordinates.
Args:
**kwargs: parameters for the camera extrinsics can be passed in
as keyword arguments to override the default values
set in __init__.
Setting T here will update the values set in init as this
value may be needed later on in the rendering pipeline e.g. for
lighting calculations.
Returns:
C: a batch of 3D locations of shape (N, 3) denoting
the locations of the center of each camera in the batch.
"""
w2v_trans = self.get_world_to_view_transform(**kwargs)
P = w2v_trans.inverse().get_matrix()
# the camera center is the translation component (the first 3 elements
# of the last row) of the inverted world-to-view
# transform (4x4 RT matrix)
C = P[:, 3, :3]
return C
def get_world_to_view_transform(self, **kwargs) -> Transform3d:
"""
Return the world-to-view transform.
Args:
**kwargs: parameters for the camera extrinsics can be passed in
as keyword arguments to override the default values
set in __init__.
Setting R and T here will update the values set in init as these
values may be needed later on in the rendering pipeline e.g. for
lighting calculations.
Returns:
T: a Transform3d object which represents a batch of transforms
of shape (N, 4, 4)
"""
R = self.R = kwargs.get("R", self.R) # pyre-ignore[16]
T = self.T = kwargs.get("T", self.T) # pyre-ignore[16]
if T.shape[0] != R.shape[0]:
msg = "Expected R, T to have the same batch dimension; got %r, %r"
raise ValueError(msg % (R.shape[0], T.shape[0]))
if T.dim() != 2 or T.shape[1:] != (3,):
msg = "Expected T to have shape (N, 3); got %r"
raise ValueError(msg % repr(T.shape))
if R.dim() != 3 or R.shape[1:] != (3, 3):
msg = "Expected R to have shape (N, 3, 3); got %r"
raise ValueError(msg % R.shape)
# Create a Transform3d object
T = Translate(T, device=T.device)
R = Rotate(R, device=R.device)
world_to_view_transform = R.compose(T)
return world_to_view_transform
def get_full_projection_transform(self, **kwargs) -> Transform3d:
"""
Return the full world-to-screen transform composing the
world-to-view and view-to-screen transforms.
Args:
**kwargs: parameters for the projection transforms can be passed in
as keyword arguments to override the default values
set in __init__.
Setting R and T here will update the values set in init as these
values may be needed later on in the rendering pipeline e.g. for
lighting calculations.
Returns:
T: a Transform3d object which represents a batch of transforms
of shape (N, 4, 4)
"""
self.R = kwargs.get("R", self.R) # pyre-ignore[16]
self.T = kwargs.get("T", self.T) # pyre-ignore[16]
world_to_view_transform = self.get_world_to_view_transform(
R=self.R, T=self.T
)
view_to_screen_transform = self.get_projection_transform(**kwargs)
return world_to_view_transform.compose(view_to_screen_transform)
def transform_points(self, points, **kwargs) -> torch.Tensor:
"""
Transform input points from world to screen space.
Args:
points: torch tensor of shape (..., 3).
Returns:
new_points: transformed points with the same shape as the input.
"""
world_to_screen_transform = self.get_full_projection_transform(**kwargs)
return world_to_screen_transform.transform_points(points)
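# ---------------------------------------------------------------------------
# Usage sketch (editor's addition, not part of the original gist): a minimal,
# hedged example of building the camera above from a pinhole intrinsic matrix
# K and projecting world points. The K values and the test point are
# hypothetical; the fx/fy -> focal_length and cx/cy -> principal_point mapping
# mirrors the demo script further down in this gist.
def _example_camera_usage():
    K = np.array([[572.4, 0.0, 325.3], [0.0, 573.6, 242.0], [0.0, 0.0, 1.0]])  # hypothetical intrinsics
    cam = OpenGLRealPerspectiveCameras(
        focal_length=((K[0, 0], K[1, 1]),),     # (fx, fy), shape Nx2, in pixels
        principal_point=((K[0, 2], K[1, 2]),),  # (cx, cy), shape Nx2, in pixels
        x0=0, y0=0,                             # pixel offset of the crop inside the full image
        w=640, h=640,                           # rendered (square) image size; crop to the real height afterwards
        znear=0.01, zfar=10.0,
        device="cpu",
    )
    # World points of shape (N, P, 3); transform_points applies the full
    # world -> screen transform built from the default R, T and the projection matrix.
    pts = torch.tensor([[[0.0, 0.0, 0.5]]], dtype=torch.float32)  # a point 0.5 m in front of the camera
    return cam.transform_points(pts)
# ---------------------------------------------------------------------------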
# test pytorch3d renderer
# TODO: make this work
# render multiple objects in a batch, one object per image
import errno
import os
import os.path as osp
import sys
import time
import struct
import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm import tqdm
from transforms3d.axangles import axangle2mat
from transforms3d.euler import euler2quat, mat2euler, quat2euler
from transforms3d.quaternions import axangle2quat, mat2quat, qinverse, qmult
# io utils
# from pytorch3d.io import load_obj, load_ply
# rendering components
from pytorch3d.renderer import (BlendParams, MeshRasterizer, MeshRenderer,
OpenGLPerspectiveCameras, PhongShader,
PointLights, RasterizationSettings,
SilhouetteShader, look_at_rotation,
look_at_view_transform)
# from pytorch3d.renderer.cameras import SfMPerspectiveCameras
from pytorch3d.renderer.cameras_real import OpenGLRealPerspectiveCameras
# datastructures
from pytorch3d.structures import Meshes, Textures, list_to_padded
# 3D transformations functions
from pytorch3d.transforms import Rotate, Translate
cur_dir = osp.dirname(osp.abspath(__file__))
sys.path.append(osp.join(cur_dir, '../'))
data_dir = osp.join(cur_dir, '../datasets/')
output_directory = osp.join(cur_dir, '../output/results')
output_directory_ren = osp.join(output_directory, 'p3d')
os.makedirs(output_directory_ren, exist_ok=True)
ply_model_root = osp.join(data_dir, "BOP_DATASETS/lm/models")
HEIGHT = 480
WIDTH = 640
IMG_SIZE = 640
ZNEAR = 0.01
ZFAR = 10.0
K = np.array([[572.4114, 0, 325.2611], [0, 573.57043, 242.04899], [0, 0, 1]])
objects = ["ape", "benchvise", "bowl", "camera", "can", "cat",
"cup", "driller", "duck", "eggbox", "glue", "holepuncher", "iron", "lamp", "phone"]
id2obj = {
1: "ape",
2: "benchvise",
3: "bowl",
4: "camera",
5: "can",
6: "cat",
7: "cup",
8: "driller",
9: "duck",
10: "eggbox",
11: "glue",
12: "holepuncher",
13: "iron",
14: "lamp",
15: "phone",
}
obj_num = len(id2obj)
obj2id = {_name: _id for _id, _name in id2obj.items()}
def load_ply(path, vertex_scale=1.0):
# https://github.com/thodan/sixd_toolkit/blob/master/pysixd/inout.py
# bop_toolkit
"""Loads a 3D mesh model from a PLY file.
:param path: Path to a PLY file.
:return: The loaded model given by a dictionary with items:
- 'pts' (nx3 ndarray),
- 'normals' (nx3 ndarray), optional
- 'colors' (nx3 ndarray), optional
- 'faces' (mx3 ndarray), optional.
- 'texture_uv' (nx2 ndarray), optional
- 'texture_uv_face' (mx6 ndarray), optional
- 'texture_file' (string), optional
"""
f = open(path, "r")
# Only triangular faces are supported.
face_n_corners = 3
n_pts = 0
n_faces = 0
pt_props = []
face_props = []
is_binary = False
header_vertex_section = False
header_face_section = False
texture_file = None
# Read the header.
while True:
# Strip the newline character(s)
line = f.readline()
if isinstance(line, str):
line = line.rstrip("\n").rstrip("\r")
else:
line = str(line, 'utf-8').rstrip("\n").rstrip("\r")
if line.startswith('comment TextureFile'):
texture_file = line.split()[-1]
elif line.startswith("element vertex"):
n_pts = int(line.split()[-1])
header_vertex_section = True
header_face_section = False
elif line.startswith("element face"):
n_faces = int(line.split()[-1])
header_vertex_section = False
header_face_section = True
elif line.startswith("element"): # Some other element.
header_vertex_section = False
header_face_section = False
elif line.startswith("property") and header_vertex_section:
# (name of the property, data type)
prop_name = line.split()[-1]
if prop_name == "s":
prop_name = "texture_u"
if prop_name == "t":
prop_name = "texture_v"
prop_type = line.split()[-2]
pt_props.append((prop_name, prop_type))
elif line.startswith("property list") and header_face_section:
elems = line.split()
if elems[-1] == "vertex_indices" or elems[-1] == 'vertex_index':
# (name of the property, data type)
face_props.append(("n_corners", elems[2]))
for i in range(face_n_corners):
face_props.append(("ind_" + str(i), elems[3]))
elif elems[-1] == 'texcoord':
# (name of the property, data type)
face_props.append(('texcoord', elems[2]))
for i in range(face_n_corners * 2):
face_props.append(('texcoord_ind_' + str(i), elems[3]))
else:
print("Warning: Not supported face property: " + elems[-1])
elif line.startswith("format"):
if "binary" in line:
is_binary = True
elif line.startswith("end_header"):
break
# Prepare data structures.
model = {}
if texture_file is not None:
model['texture_file'] = texture_file
model["pts"] = np.zeros((n_pts, 3), np.float)
if n_faces > 0:
model["faces"] = np.zeros((n_faces, face_n_corners), np.float)
# print(pt_props)
pt_props_names = [p[0] for p in pt_props]
face_props_names = [p[0] for p in face_props]
# print(pt_props_names)
is_normal = False
if {"nx", "ny", "nz"}.issubset(set(pt_props_names)):
is_normal = True
model["normals"] = np.zeros((n_pts, 3), np.float)
is_color = False
if {"red", "green", "blue"}.issubset(set(pt_props_names)):
is_color = True
model["colors"] = np.zeros((n_pts, 3), np.float)
is_texture_pt = False
if {"texture_u", "texture_v"}.issubset(set(pt_props_names)):
is_texture_pt = True
model["texture_uv"] = np.zeros((n_pts, 2), np.float)
is_texture_face = False
if {'texcoord'}.issubset(set(face_props_names)):
is_texture_face = True
model['texture_uv_face'] = np.zeros((n_faces, 6), np.float64)
# Formats for the binary case.
formats = {
"float": ("f", 4),
"double": ("d", 8),
"int": ("i", 4),
"uchar": ("B", 1),
}
# Load vertices.
for pt_id in range(n_pts):
prop_vals = {}
load_props = ["x", "y", "z", "nx", "ny", "nz",
"red", "green", "blue", "texture_u", "texture_v"]
if is_binary:
for prop in pt_props:
format = formats[prop[1]]
read_data = f.read(format[1])
val = struct.unpack(format[0], read_data)[0]
if prop[0] in load_props:
prop_vals[prop[0]] = val
else:
elems = f.readline().rstrip("\n").rstrip("\r").split()
for prop_id, prop in enumerate(pt_props):
if prop[0] in load_props:
prop_vals[prop[0]] = elems[prop_id]
model["pts"][pt_id, 0] = float(prop_vals["x"])
model["pts"][pt_id, 1] = float(prop_vals["y"])
model["pts"][pt_id, 2] = float(prop_vals["z"])
if is_normal:
model["normals"][pt_id, 0] = float(prop_vals["nx"])
model["normals"][pt_id, 1] = float(prop_vals["ny"])
model["normals"][pt_id, 2] = float(prop_vals["nz"])
if is_color:
model["colors"][pt_id, 0] = float(prop_vals["red"])
model["colors"][pt_id, 1] = float(prop_vals["green"])
model["colors"][pt_id, 2] = float(prop_vals["blue"])
if is_texture_pt:
model["texture_uv"][pt_id, 0] = float(prop_vals["texture_u"])
model["texture_uv"][pt_id, 1] = float(prop_vals["texture_v"])
# Load faces.
for face_id in range(n_faces):
prop_vals = {}
if is_binary:
for prop in face_props:
format = formats[prop[1]]
val = struct.unpack(format[0], f.read(format[1]))[0]
if prop[0] == "n_corners":
if val != face_n_corners:
raise ValueError("Only triangular faces are supported.")
# print("Number of face corners: " + str(val))
# exit(-1)
elif prop[0] == 'texcoord':
if val != face_n_corners * 2:
raise ValueError('Wrong number of UV face coordinates.')
else:
prop_vals[prop[0]] = val
else:
elems = f.readline().rstrip("\n").rstrip("\r").split()
for prop_id, prop in enumerate(face_props):
if prop[0] == "n_corners":
if int(elems[prop_id]) != face_n_corners:
raise ValueError("Only triangular faces are supported.")
elif prop[0] == 'texcoord':
if int(elems[prop_id]) != face_n_corners * 2:
raise ValueError('Wrong number of UV face coordinates.')
else:
prop_vals[prop[0]] = elems[prop_id]
model["faces"][face_id, 0] = int(prop_vals["ind_0"])
model["faces"][face_id, 1] = int(prop_vals["ind_1"])
model["faces"][face_id, 2] = int(prop_vals["ind_2"])
if is_texture_face:
for i in range(6):
model['texture_uv_face'][face_id, i] = float(
prop_vals['texcoord_ind_{}'.format(i)])
f.close()
model['pts'] *= vertex_scale
return model
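# Usage sketch (editor's addition, hedged): load_ply returns a plain dict of
# numpy arrays; the path below is hypothetical and the optional keys depend on
# what the PLY header declares.
def _example_load_ply(path="obj_000001.ply"):
    model = load_ply(path, vertex_scale=0.001)   # BOP models are in millimetres -> metres
    print("pts", model["pts"].shape)             # (n_vertices, 3)
    if "faces" in model:
        print("faces", model["faces"].shape)     # (n_faces, 3) vertex indices
    # other optional keys: "normals", "colors", "texture_uv", "texture_uv_face", "texture_file"
    return model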
def grid_show(ims, titles=None, row=1, col=3, dpi=200, save_path=None, title_fontsize=5, show=True):
if row * col < len(ims):
print('_____________row*col < len(ims)___________')
col = int(np.ceil(len(ims) / row))
fig = plt.figure(dpi=dpi, figsize=plt.figaspect(row / float(col)))
k = 0
for i in range(row):
for j in range(col):
plt.subplot(row, col, k + 1)
plt.axis('off')
plt.imshow(ims[k])
if titles is not None:
# plt.title(titles[k], size=title_fontsize)
plt.text(0.5, 1.08, titles[k],
horizontalalignment='center',
fontsize=title_fontsize,
transform=plt.gca().transAxes)
k += 1
if k == len(ims):
break
# plt.tight_layout()
if show:
plt.show()
else:
if save_path is not None:
mkdir_p(osp.dirname(save_path))
plt.savefig(save_path)
return fig
def mkdir_p(dirname):
"""Like "mkdir -p", make a dir recursively, but do nothing if the dir
exists.
Args:
dirname(str):
"""
assert dirname is not None
if dirname == "" or os.path.isdir(dirname):
return
try:
os.makedirs(dirname)
except OSError as e:
if e.errno != errno.EEXIST:
raise e
def print_stat(data, name=""):
print(name, "min", data.min(), "max", data.max(),
"mean", data.mean(), "std", data.std(),
"sum", data.sum(), "shape", data.shape)
###################################################################################################
def load_ply_models(model_paths, device='cuda', dtype=torch.float32, vertex_scale=0.001):
ply_models = [load_ply(ply_path, vertex_scale=vertex_scale) for ply_path in model_paths]
verts = [torch.tensor(m['pts'], device=device, dtype=dtype) for m in ply_models]
faces = [torch.tensor(m['faces'], device=device, dtype=torch.int64) for m in ply_models]
for m in ply_models:
if m['colors'].max() > 1.1:
m['colors'] /= 255.0
verts_rgb_list = [torch.tensor(m['colors'], device=device, dtype=dtype) for m in ply_models] # [V,3]
res_models = []
for i in range(len(ply_models)):
model = {}
model['verts'] = verts[i]
model['faces'] = faces[i]
model['verts_rgb'] = verts_rgb_list[i]
res_models.append(model)
return res_models
def main():
# Set the cuda device
device = torch.device("cuda:0")
torch.cuda.set_device(device)
###########################
# load objects
###########################
objs = objects
# leftover unassigned debug value (has no effect):
# np.array([[-5.87785252e-01, 8.09016994e-01, 0.00000000e+00], [-4.95380036e-17, -3.59914664e-17, -1.00000000e+00], [-8.09016994e-01, -5.87785252e-01, 6.12323400e-17]])
# obj_paths = [osp.join(model_root, '{}/textured.obj'.format(cls_name)) for cls_name in objs]
# texture_paths = [osp.join(model_root, '{}/texture_map.png'.format(cls_name)) for cls_name in objs]
ply_paths = [osp.join(ply_model_root, "obj_{:06d}.ply".format(obj2id[cls_name]))
for cls_name in objs]
models = load_ply_models(ply_paths, vertex_scale=0.001)
cameras = OpenGLRealPerspectiveCameras(
focal_length=((K[0,0], K[1,1]),), # Nx2
principal_point=((K[0,2], K[1,2]),), # Nx2
x0=0,
y0=0,
w=WIDTH,
h=WIDTH,  # render a square image (longer side); the output is cropped to HEIGHT later
znear=ZNEAR,
zfar=ZFAR,
device=device,
)
# To blend the 100 faces we set a few parameters which control the opacity and the sharpness of
# edges. Refer to blending.py for more details.
blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
# Define the settings for rasterization and shading. Here we set the output image to be of size
# 640x640. To form the blended image we use 100 faces for each pixel. Refer to rasterize_meshes.py
# for an explanation of this parameter.
silhouette_raster_settings = RasterizationSettings(
image_size=IMG_SIZE, # longer side or scaled longer side
blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma,
faces_per_pixel=100, # the nearest faces_per_pixel points along the z-axis.
bin_size=0
)
# Create a silhouette mesh renderer by composing a rasterizer and a shader.
silhouette_renderer = MeshRenderer(
rasterizer=MeshRasterizer(
cameras=cameras,
raster_settings=silhouette_raster_settings
),
shader=SilhouetteShader(blend_params=blend_params)
)
# We will also create a phong renderer. This is simpler and only needs to render one face per pixel.
phong_raster_settings = RasterizationSettings(
image_size=IMG_SIZE,
blur_radius=0.0,
faces_per_pixel=1,
bin_size=0
)
# We can add a point light in front of the object.
lights = PointLights(device=device, location=((2.0, 2.0, -2.0),))
phong_renderer = MeshRenderer(
rasterizer=MeshRasterizer(
cameras=cameras,
raster_settings=phong_raster_settings
),
shader=PhongShader(device=device, lights=lights)
)
# pose =============================================
R1 = axangle2mat((1, 0, 0), angle=0.5 * np.pi)
R2 = axangle2mat((0, 0, 1), angle=-0.7 * np.pi)
R = np.dot(R1, R2)
print("R det", torch.det(torch.tensor(R)))
quat = mat2quat(R)
t = np.array([-0.1, 0.1, 0.7], dtype=np.float32)
t2 = np.array([0.1, 0.1, 0.7], dtype=np.float32)
t3 = np.array([-0.1, -0.1, 0.7], dtype=np.float32)
t4 = np.array([0.1, -0.1, 0.7], dtype=np.float32)
t5 = np.array([0, 0.1, 0.7], dtype=np.float32)
batch_size = 3
Rs = [R, R.copy(), R.copy(), R.copy(), R.copy()][:batch_size]
print("R", R)
quats = [quat, quat.copy(), quat.copy(), quat.copy(), quat.copy()][:batch_size]
ts = [t, t2, t3, t4, t5][:batch_size]
runs = 100
t_render = 0
for i in tqdm(range(runs)):
t_render_start = time.perf_counter()
obj_ids = np.random.randint(0, len(objs), size=len(quats))
# Render the objs providing the values of R and T.
batch_verts_rgb = list_to_padded([models[obj_id]['verts_rgb'] for obj_id in obj_ids]) # B, Vmax, 3
batch_textures = Textures(verts_rgb=batch_verts_rgb.to(device))
batch_mesh = Meshes(
verts=[models[obj_id]['verts'] for obj_id in obj_ids],
faces=[models[obj_id]['faces'] for obj_id in obj_ids],
textures=batch_textures,
)
batch_R = torch.tensor(np.stack(Rs), device=device, dtype=torch.float32).permute(0,2,1) # Bx3x3
batch_T = torch.tensor(np.stack(ts), device=device, dtype=torch.float32) # Bx3
silhouette = silhouette_renderer(meshes_world=batch_mesh, R=batch_R, T=batch_T)
image_ref = phong_renderer(meshes_world=batch_mesh, R=batch_R, T=batch_T)
# crop results
silhouette = silhouette[:, :HEIGHT, :WIDTH, :].cpu().numpy()
image_ref = image_ref[:, :HEIGHT, :WIDTH, :3].cpu().numpy()
t_render += time.perf_counter() - t_render_start
if True:
pred_images = image_ref
for i in range(pred_images.shape[0]):
pred_mask = silhouette[i, :, :, 3].astype('float32')
print("num rendered images", pred_images.shape[0])
image = pred_images[i]
print('image', image.shape)
print('dr mask area: ', pred_mask.sum())
print_stat(pred_mask, "pred_mask")
show_ims = [image, pred_mask]
show_titles = ['image', 'mask']
grid_show(show_ims, show_titles, row=1, col=2)
print("runs: {}, {:.3f}fps, {:.3f}ms/im".format(runs, runs / t_render, t_render / runs * 1000))
if __name__ == '__main__':
main()
@ldepn

ldepn commented Apr 2, 2020

Hello, which version of pytorch3d was this code modified from? I'm using the latest version, and it is missing some functions, e.g. from .utils import TensorProperties, convert_to_tensors_and_broadcast

@densechen

Hi,
Thanks for providing such nice code! Can you tell me what the function of x0 and y0 is in the init method?
I used this code to render some images, but I found that they do not align with the ones captured by the real camera.
The following is the observed one:
[screenshot: observed image]

and the following is the rendered one:
[screenshot: rendered image]

There is clearly a shift of a few pixels; could you please give me some help with this?
Best
Chen

@wangg12

wangg12 commented Apr 16, 2020

Please attach your code and data if possible; otherwise I cannot help.

@densechen

Hi, Wang,
Thanks for your reply!
The settings are:
class settings():

HEIGHT = 480
WIDTH = 640
IMG_SIZE = 640
ZNEAR = 0.01
ZFAR = 10.0
K = np.array(
    [[1077.8360, 0, 323.7872],
     [0, 1078.1890, 279.1890],
     [0, 0, 1],
     ])

fov_x, fov_y = intrinsics_fov(
    w=WIDTH, h=HEIGHT, fx=K[0, 0], fy=K[1, 1], cx=K[0, 2], cy=K[1, 2])

num_points = 512

device = torch.device("cuda:1")
milestone = [30, ]


And I used the YCB dataset,
rotation = np.array((meta["poses"][:3, :3, idx].T @ np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]])), dtype=np.float32)
translation = np.array(meta["poses"][:3, 3, idx] * np.array([-1, 1, 1]), dtype=np.float32)

And the renderer is defined by:
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import errno
from pytorch3d.transforms import Rotate, Translate
from pytorch3d.structures import Meshes, Textures, list_to_padded
from pytorch3d.renderer import (BlendParams, MeshRasterizer, OpenGLPerspectiveCameras, TexturedSoftPhongShader,
PointLights, RasterizationSettings, SoftSilhouetteShader, look_at_rotation, look_at_view_transform)
from render.mesh_render import MeshRenderer
from render.real_camera import OpenGLRealPerspectiveCameras
from pytorch3d.io import load_obj, load_ply, load_objs_as_meshes
from transforms3d.quaternions import axangle2quat, mat2quat, qinverse, qmult
from transforms3d.euler import euler2quat, mat2euler, quat2euler
from transforms3d.axangles import axangle2mat
from tqdm import tqdm
import matplotlib.pyplot as plt
import struct
import time
import sys
import os
import math
import numpy as np
from typing import Tuple
import torch
import torch.nn.functional as F

from pytorch3d.transforms import Rotate, Transform3d, Translate

from pytorch3d.renderer.utils import TensorProperties, convert_to_tensors_and_broadcast
from config.settings import settings
from pytorch3d.renderer.cameras import SfMPerspectiveCameras

def define_camera():
# define camera
cameras = OpenGLRealPerspectiveCameras(
focal_length=((settings.K[0, 0], settings.K[1, 1]),), # Nx2
principal_point=((settings.K[0, 2], settings.K[1, 2]), ), # Nx2
x0=0,
y0=0,
w=settings.WIDTH,
h=settings.WIDTH,
znear=settings.ZNEAR,
zfar=settings.ZFAR,
device=settings.device
)

# cameras = OpenGLPerspectiveCameras(
#     znear=settings.ZNEAR,
#     zfar=settings.ZFAR,
#     aspect_ratio=settings.fov_x/settings.fov_y,
#     fov=settings.fov_x,
#     device=settings.device
# )

# cameras = SfMPerspectiveCameras(
#     focal_length=((settings.K[0, 0], settings.K[1, 1]),),  # Nx2
#     principal_point=((settings.K[0, 2], settings.K[1, 2]), ),  # Nx2
#     device=settings.device
# )

# To blend the 100 faces we set a few parameters which control the opacity and the sharpness of edges. Refer to blending.py for more details.
blend_params = BlendParams(sigma=1e-4, gamma=1e-4)

# Define the settings for rasterization and shading. Here we set the output image to be of size 640x640. To form the blended image we use 100 faces for each pixel. Refer to rasterize_meshes.py for an explanation of this parameter.
silhouette_raster_settings = RasterizationSettings(
    image_size=settings.IMG_SIZE,  # longer side or scaled longer side
    blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma,
    # The nearest faces_per_pixel points along the z-axis.
    faces_per_pixel=100,
    bin_size=0
)
# Create a silhouette mesh renderer by composing a rasterizer and a shader
silhouete_renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=cameras,
        raster_settings=silhouette_raster_settings
    ),
    shader=SoftSilhouetteShader(blend_params=blend_params)
).to(settings.device)
# We will also create a phong renderer. This is simpler and only needs to render one face per pixel.
phong_raster_settings = RasterizationSettings(
    image_size=settings.IMG_SIZE,
    blur_radius=0.0,
    faces_per_pixel=1,
    bin_size=0
)
# We can add a point light in front of the object.
lights = PointLights(device=settings.device, location=((2.0, 2.0, -2.0,),))
phong_renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=cameras,
        raster_settings=phong_raster_settings
    ),
    shader=TexturedSoftPhongShader(device=settings.device, lights=lights)
).to(settings.device)

return silhouete_renderer, phong_renderer

Then, get the rendered image:
silhouete, _ = silhouete_renderer(meshes_world=mesh, R=rotation, T=translation)
image_ref, depth_ref = phong_renderer(meshes_world=mesh, R=rotation, T=translation)

I still don't know what is wrong here; looking forward to your reply.
Best

@wangg12

wangg12 commented Apr 16, 2020

Have you successfully rendered it with other methods like OpenGL?

@densechen

Not yet.
But the result looks fine except for a shift of a few pixels (about 15~30 pixels). So I wonder whether the parameters x0, y0 could help?

@densechen

Also, I have rendered the color image; it is shifted by a few pixels as well.

@wangg12

wangg12 commented Apr 16, 2020

I would recommend debugging by comparing it to other rendering methods besides the ground-truth segmentation.

@densechen

Thanks for your advice. I will give it a try.

@EphChem

EphChem commented Apr 20, 2020

Hi @wangg12, thank you for making your code available! I am using your camera class and it works for rendering silhouettes, but when I use zbuf to get depth, the depth map is flipped along the y-axis. In the pytorch3d repo they initially had this issue and have fixed it here. Is your code adapted to this change?

@wangg12

wangg12 commented Apr 20, 2020

I haven't looked into that yet. It would be nice if you could look into this and propose a fix.

@EphChem

EphChem commented Apr 20, 2020

OK, I'm looking into this today. As a quick fix, I negated the y-axis of the rotation matrix for the depth rasterizer: R[:,:,1] = -R[:,:,1]. This flips the object, but now there is an offset along the y-axis. See the screenshot below:
[screenshot: depth rendering with y-axis offset]

@wangg12

wangg12 commented Apr 20, 2020

Does your R also contain translation? Is it 3x3 or 4x4?

@EphChem

EphChem commented Apr 20, 2020

My R is 3x3 so no translation

@ForrestPi

Thanks for your code. I have another problem: I have to train on a batch of cropped images, so the FOV and optical center differ for each crop. How should I set up the perspective camera?
