Created
November 16, 2021 04:23
-
-
Save woo1/210e7480f4cd65b69e426fae09a9b46f to your computer and use it in GitHub Desktop.
Convert PARE SMPL output pickle data to SMPL-X using the FrankMocap module.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import numpy as np | |
import sys | |
import os | |
sys.path.append(os.getcwd()) | |
from bodymocap.body_mocap_api import BodyMocap | |
import pickle | |
import joblib | |
import os.path as osp | |
import trimesh | |
from smplx import SMPL as _SMPL | |
from bodymocap import constants | |
from smplx.lbs import vertices2joints | |
from smplx.utils import SMPLOutput | |
class SMPL(_SMPL):
    """Extension of the official SMPL implementation to support more joints.

    Extra joints are regressed from the posed mesh vertices with a separately
    loaded regressor (``J_regressor_extra``), concatenated after the standard
    SMPL joints, and then selected/re-ordered with ``constants.JOINT_MAP``
    for the names listed in ``constants.JOINT_NAMES``.
    """

    # Default location of the extra-joint regressor; relative paths are
    # resolved against the current working directory.
    JOINT_REGRESSOR_TRAIN_EXTRA = 'extra_data/body_module/data_from_spin/J_regressor_extra.npy'

    def __init__(self, *args, joint_regressor_extra=None, **kwargs):
        """Build the SMPL layer and load the extra-joint regressor.

        Args:
            *args, **kwargs: forwarded unchanged to ``smplx.SMPL.__init__``.
            joint_regressor_extra: optional path to the ``.npy`` file holding
                the extra-joint regressor. Defaults to
                ``JOINT_REGRESSOR_TRAIN_EXTRA`` (the previously hard-coded path).
        """
        super().__init__(*args, **kwargs)
        joints = [constants.JOINT_MAP[name] for name in constants.JOINT_NAMES]
        if joint_regressor_extra is None:
            joint_regressor_extra = self.JOINT_REGRESSOR_TRAIN_EXTRA
        J_regressor_extra = np.load(joint_regressor_extra)
        # Registered as a buffer so it follows the module across .to(device).
        self.register_buffer('J_regressor_extra', torch.tensor(J_regressor_extra, dtype=torch.float32))
        # Index tensor used in forward() to pick/re-order the concatenated joints.
        self.joint_map = torch.tensor(joints, dtype=torch.long)

    def forward(self, *args, **kwargs):
        """Run SMPL and return an ``SMPLOutput`` with the extended joint set.

        ``get_skin=True`` is always injected into the keyword arguments before
        delegating to ``smplx.SMPL.forward``. The returned ``SMPLOutput``
        carries the parent's vertices, global_orient, body_pose, betas and
        full_pose, with ``joints`` replaced by the re-mapped joint set.
        """
        kwargs['get_skin'] = True
        smpl_output = super().forward(*args, **kwargs)
        extra_joints = vertices2joints(self.J_regressor_extra, smpl_output.vertices)
        joints = torch.cat([smpl_output.joints, extra_joints], dim=1)
        joints = joints[:, self.joint_map, :]
        return SMPLOutput(vertices=smpl_output.vertices,
                          global_orient=smpl_output.global_orient,
                          body_pose=smpl_output.body_pose,
                          joints=joints,
                          betas=smpl_output.betas,
                          full_pose=smpl_output.full_pose)
def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
    """
    This function is borrowed from https://github.com/kornia/kornia

    Convert a batch of 3x4 rotation matrices to quaternions ordered
    ``(w, x, y, z)`` (scalar first).

    This algorithm is based on algorithm described in
    https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201

    Only the left 3x3 block of each matrix is read; the 4th column is ignored.

    Args:
        rotation_matrix (Tensor): the rotation matrices to convert.
        eps (float): threshold on the (2, 2) entry of the transposed matrix
            used when choosing which of the four formulas to apply.

    Return:
        Tensor: the rotation in quaternion

    Raises:
        TypeError: if ``rotation_matrix`` is not a ``torch.Tensor``.
        ValueError: if it is not a 3-D tensor whose last two dims are (3, 4).

    Shape:
        - Input: :math:`(N, 3, 4)`
        - Output: :math:`(N, 4)`

    Example:
        >>> input = torch.rand(4, 3, 4)  # Nx3x4
        >>> output = tgm.rotation_matrix_to_quaternion(input)  # Nx4
    """
    if not torch.is_tensor(rotation_matrix):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(rotation_matrix)))

    # The body below indexes dims 1 and 2 explicitly, so anything that is not
    # exactly 3-D (including a single un-batched 3x4 matrix) must be rejected
    # here rather than failing later inside torch.transpose.
    if rotation_matrix.dim() != 3:
        raise ValueError(
            "Input size must be a three dimensional tensor. Got {}".format(
                rotation_matrix.shape))
    if not rotation_matrix.shape[-2:] == (3, 4):
        raise ValueError(
            "Input size must be a N x 3 x 4 tensor. Got {}".format(
                rotation_matrix.shape))

    rmat_t = torch.transpose(rotation_matrix, 1, 2)

    # Branch selectors (per batch element).
    mask_d2 = rmat_t[:, 2, 2] < eps
    mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1]
    mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1]

    # Four candidate (unnormalized) quaternions; t_i is the value whose
    # square root later normalizes candidate q_i.
    t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
    q0 = torch.stack([rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
                      t0, rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
                      rmat_t[:, 2, 0] + rmat_t[:, 0, 2]], -1)
    t0_rep = t0.repeat(4, 1).t()

    t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
    q1 = torch.stack([rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
                      rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
                      t1, rmat_t[:, 1, 2] + rmat_t[:, 2, 1]], -1)
    t1_rep = t1.repeat(4, 1).t()

    t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
    q2 = torch.stack([rmat_t[:, 0, 1] - rmat_t[:, 1, 0],
                      rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
                      rmat_t[:, 1, 2] + rmat_t[:, 2, 1], t2], -1)
    t2_rep = t2.repeat(4, 1).t()

    # Trace-based candidate: the scalar part w is t3.
    t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
    q3 = torch.stack([t3, rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
                      rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
                      rmat_t[:, 0, 1] - rmat_t[:, 1, 0]], -1)
    t3_rep = t3.repeat(4, 1).t()

    # The four masks are mutually exclusive and cover every case, so exactly
    # one candidate survives per batch element.
    mask_c0 = mask_d2 * mask_d0_d1
    mask_c1 = mask_d2 * ~mask_d0_d1
    mask_c2 = ~mask_d2 * mask_d0_nd1
    mask_c3 = ~mask_d2 * ~mask_d0_nd1
    mask_c0 = mask_c0.view(-1, 1).type_as(q0)
    mask_c1 = mask_c1.view(-1, 1).type_as(q1)
    mask_c2 = mask_c2.view(-1, 1).type_as(q2)
    mask_c3 = mask_c3.view(-1, 1).type_as(q3)

    # q = 0.5 * q_selected / sqrt(t_selected)
    q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3
    q /= torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 +  # noqa
                    t2_rep * mask_c2 + t3_rep * mask_c3)  # noqa
    q *= 0.5
    return q
def quaternion_to_angle_axis(quaternion: torch.Tensor) -> torch.Tensor:
    """
    This function is borrowed from https://github.com/kornia/kornia

    Turn scalar-first quaternions ``(w, x, y, z)`` into axis-angle vectors.
    Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h

    For a non-zero vector part the result is ``(x, y, z) * angle / |(x, y, z)|``.
    When the vector part is exactly zero, the scale factor falls back to 2.

    Args:
        quaternion (torch.Tensor): tensor with quaternions in its last dim.

    Return:
        torch.Tensor: tensor with angle axis of rotation.

    Raises:
        TypeError: if ``quaternion`` is not a tensor.
        ValueError: if its last dimension is not 4.

    Shape:
        - Input: :math:`(*, 4)` where `*` means, any number of dimensions
        - Output: :math:`(*, 3)`

    Example:
        >>> quaternion = torch.rand(2, 4)  # Nx4
        >>> angle_axis = tgm.quaternion_to_angle_axis(quaternion)  # Nx3
    """
    if not torch.is_tensor(quaternion):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(quaternion)))
    if quaternion.shape[-1] != 4:
        raise ValueError("Input must be a tensor of shape Nx4 or 4. Got {}"
                         .format(quaternion.shape))

    w = quaternion[..., 0]
    x = quaternion[..., 1]
    y = quaternion[..., 2]
    z = quaternion[..., 3]

    vec_norm_sq = x * x + y * y + z * z
    vec_norm = torch.sqrt(vec_norm_sq)

    # Flip both atan2 arguments when w is negative so q and -q map to the same
    # rotation.
    half_angle = torch.where(w < 0.0,
                             torch.atan2(-vec_norm, -w),
                             torch.atan2(vec_norm, w))
    angle = 2.0 * half_angle

    scale = torch.where(vec_norm_sq > 0.0,
                        angle / vec_norm,
                        2.0 * torch.ones_like(vec_norm))

    return torch.stack((x * scale, y * scale, z * scale), dim=-1)
def rotation_matrix_to_angle_axis(rotation_matrix):
    """
    Convert rotation matrices to Rodrigues (axis-angle) vectors.

    Args:
        rotation_matrix (Tensor): either ``(N, 3, 4)`` matrices (only the
            left 3x3 block is used), or any tensor whose last two dims are
            ``(3, 3)``. The latter is flattened to ``(-1, 3, 3)`` and padded
            to 3x4 before conversion.

    Returns:
        Tensor: ``(N, 3)`` Rodrigues vectors. NaN entries are replaced by 0.

    Example:
        >>> input = torch.rand(2, 3, 4)  # Nx3x4
        >>> output = tgm.rotation_matrix_to_angle_axis(input)  # Nx3
    """
    if rotation_matrix.shape[-2:] == (3, 3):
        rot_mat = rotation_matrix.reshape(-1, 3, 3)
        # rotation_matrix_to_quaternion requires a 3x4 input but never reads
        # the 4th column. Build the pad with the input's dtype/device so no
        # cross-dtype concatenation or host->device copy is needed.
        hom_col = torch.tensor([0, 0, 1], dtype=rot_mat.dtype, device=rot_mat.device)
        hom_col = hom_col.view(1, 3, 1).expand(rot_mat.shape[0], 3, 1)
        rotation_matrix = torch.cat([rot_mat, hom_col], dim=-1)
    quaternion = rotation_matrix_to_quaternion(rotation_matrix)
    aa = quaternion_to_angle_axis(quaternion)
    # Degenerate inputs can yield NaN; replace them with a zero rotation.
    aa[torch.isnan(aa)] = 0.0
    return aa
# --- Convert one PARE result pickle into SMPL-X outputs via FrankMocap ---

default_checkpoint_body_smplx = './extra_data/body_module/pretrained_weights/smplx-03-28-46060-w_spin_mlc3d_46582-2089_2020_03_28-21_56_16.pt'
smpl_dir = './extra_data/smpl/'
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
body_mocap = BodyMocap(default_checkpoint_body_smplx, smpl_dir, device=device, use_smplx=True)
smplx_model = body_mocap.smpl

# Number of hand-pose coefficients passed to the SMPL-X model (all zeros here).
NUM_HAND_POSE_COMPS = 12

# PARE result to convert. The first command-line argument overrides the default.
source_path = sys.argv[1] if len(sys.argv) > 1 else '/PARE/logs/demo/test_/pare_results/000001.pkl'
# NOTE: joblib.load unpickles arbitrary objects -- only use on trusted files.
data = joblib.load(source_path)
print('pare_result')
for key, value in data.items():
    print(key, getattr(value, 'shape', type(value)))
print()

betas = torch.tensor(data['pred_shape']).to(device)
# pred_pose is reshaped to (-1, 3, 3) below, i.e. it holds 3x3 rotation
# matrices. Index 0 is treated as the global orientation and the rest as the
# body pose. Presumably (batch, 24, 3, 3) -- TODO confirm against PARE output.
pred_pose = torch.tensor(data['pred_pose']).to(device)
# Infer the batch size instead of assuming a single person/frame. A hard-coded
# 1 would fold all people into one pose vector.
batch_size = pred_pose.shape[0]
body_pose = pred_pose[:, 1:].contiguous()
global_orient = pred_pose[:, 0].unsqueeze(1).contiguous()

body_pose2 = rotation_matrix_to_angle_axis(body_pose.reshape(-1, 3, 3)).reshape(batch_size, -1)
global_orient2 = rotation_matrix_to_angle_axis(global_orient.reshape(-1, 3, 3)).reshape(batch_size, -1)
print('saved2', body_pose2.shape)
print('global_orient2', global_orient2.shape)

left_hand_pose = torch.zeros(batch_size, NUM_HAND_POSE_COMPS, dtype=torch.float32, device=device)
right_hand_pose = torch.zeros(batch_size, NUM_HAND_POSE_COMPS, dtype=torch.float32, device=device)

with torch.no_grad():
    smplx_output = smplx_model(
        betas=betas,
        body_pose=body_pose2,
        global_orient=global_orient2,
        right_hand_pose=right_hand_pose,
        left_hand_pose=left_hand_pose,
        pose2rot=True)

pred_vertices = smplx_output.vertices.detach().cpu().numpy()
saved_body_pose = smplx_output.body_pose.detach().cpu().numpy()
print('saved body_pose', saved_body_pose.shape)
result = {'body_pose': saved_body_pose,
          'right_hand_pose': right_hand_pose.cpu().detach().numpy(),
          'left_hand_pose': left_hand_pose.cpu().detach().numpy(),
          # NOTE(review): this is PARE's rotation-matrix global orientation,
          # not the axis-angle `global_orient2` actually fed to the model, while
          # 'body_pose' comes from the model output. Kept unchanged for
          # compatibility -- confirm what consumers of this pickle expect.
          'global_orient': global_orient.detach().cpu().numpy().astype(np.float32),
          'v': pred_vertices}

sav_dir = os.path.join(os.getcwd(), 'output/temp')
os.makedirs(sav_dir, exist_ok=True)
pkl_path = osp.join(sav_dir, os.path.basename(source_path))
obj_path = osp.join(sav_dir, os.path.splitext(os.path.basename(source_path))[0] + '.obj')

# Open the output only after all computation succeeded, so a failure above
# cannot leave an empty or truncated pickle behind.
with open(pkl_path, 'wb') as f:
    pickle.dump(result, f)

# Only the first mesh of the batch is exported as .obj.
mesh = trimesh.Trimesh(pred_vertices[0], smplx_model.faces)
mesh.export(obj_path)
print('saved in ', pkl_path)
print('saved in ', obj_path)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment