Skip to content

Instantly share code, notes, and snippets.

@woo1
Created November 16, 2021 04:23
Show Gist options
  • Save woo1/210e7480f4cd65b69e426fae09a9b46f to your computer and use it in GitHub Desktop.
Save woo1/210e7480f4cd65b69e426fae09a9b46f to your computer and use it in GitHub Desktop.
Convert PARE SMPL output (pickle data) to SMPL-X using the FrankMocap module
import torch
import numpy as np
import sys
import os
sys.path.append(os.getcwd())
from bodymocap.body_mocap_api import BodyMocap
import pickle
import joblib
import os.path as osp
import trimesh
from smplx import SMPL as _SMPL
from bodymocap import constants
from smplx.lbs import vertices2joints
from smplx.utils import SMPLOutput
class SMPL(_SMPL):
    """SMPL body model extended with the extra joints used by SPIN.

    After the standard SMPL forward pass, additional joints are regressed
    from the mesh vertices with the ``J_regressor_extra.npy`` matrix.  They
    are appended to the model's own joints, and the combined set is
    reordered according to ``constants.JOINT_MAP`` / ``constants.JOINT_NAMES``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Position of every requested joint name inside the concatenated
        # (model joints + extra joints) tensor built in forward().
        wanted = [constants.JOINT_MAP[name] for name in constants.JOINT_NAMES]
        regressor_path = 'extra_data/body_module/data_from_spin//J_regressor_extra.npy'
        regressor = np.load(regressor_path)
        self.register_buffer('J_regressor_extra',
                             torch.tensor(regressor, dtype=torch.float32))
        # NOTE(review): a plain attribute, not a registered buffer, so it is
        # not moved by ``.to(device)`` and stays on the CPU.
        self.joint_map = torch.tensor(wanted, dtype=torch.long)

    def forward(self, *args, **kwargs):
        """Run SMPL, then append and reorder the extra regressed joints.

        Returns:
            SMPLOutput: the base model's vertices, global_orient, body_pose,
            betas and full_pose, with ``joints`` replaced by the reordered
            concatenation of model joints and extra joints.
        """
        # Kept from the SPIN implementation; whether the installed smplx
        # version uses this flag cannot be told from here.
        kwargs['get_skin'] = True
        base = super().forward(*args, **kwargs)

        extra = vertices2joints(self.J_regressor_extra, base.vertices)
        combined = torch.cat([base.joints, extra], dim=1)
        reordered = combined[:, self.joint_map, :]

        return SMPLOutput(
            vertices=base.vertices,
            global_orient=base.global_orient,
            body_pose=base.body_pose,
            joints=reordered,
            betas=base.betas,
            full_pose=base.full_pose,
        )
def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
    """Convert a batch of 3x4 rotation matrices to quaternions.

    Borrowed from https://github.com/kornia/kornia; the algorithm follows
    https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201

    Only the leading 3x3 block of each matrix is read; the fourth column is
    ignored.

    Args:
        rotation_matrix (Tensor): rotation matrices of shape ``(N, 3, 4)``.
        eps (float): threshold on the ``[2, 2]`` entry (of the transposed
            matrix) used when choosing which formula to apply.

    Returns:
        Tensor: quaternions of shape ``(N, 4)`` in ``(w, x, y, z)`` order.

    Raises:
        TypeError: if ``rotation_matrix`` is not a ``torch.Tensor``.
        ValueError: if ``rotation_matrix`` is not a 3-D ``(N, 3, 4)`` tensor.

    Example:
        >>> input = torch.rand(4, 3, 4)  # Nx3x4
        >>> output = rotation_matrix_to_quaternion(input)  # Nx4
    """
    if not torch.is_tensor(rotation_matrix):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(rotation_matrix)))
    # Require exactly three dims: a 2-D input would otherwise pass the shape
    # check below and then crash inside torch.transpose(..., 1, 2).
    if rotation_matrix.dim() != 3:
        raise ValueError(
            "Input size must be a three dimensional tensor. Got {}".format(
                rotation_matrix.shape))
    if not rotation_matrix.shape[-2:] == (3, 4):
        raise ValueError(
            "Input size must be a N x 3 x 4 tensor. Got {}".format(
                rotation_matrix.shape))

    rmat_t = torch.transpose(rotation_matrix, 1, 2)

    # Per matrix, one of four algebraically equivalent formulas is chosen
    # from the diagonal entries; each is divided by sqrt(t_i) below.
    mask_d2 = rmat_t[:, 2, 2] < eps
    mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1]
    mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1]

    t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
    q0 = torch.stack([rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
                      t0, rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
                      rmat_t[:, 2, 0] + rmat_t[:, 0, 2]], -1)
    t0_rep = t0.repeat(4, 1).t()

    t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
    q1 = torch.stack([rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
                      rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
                      t1, rmat_t[:, 1, 2] + rmat_t[:, 2, 1]], -1)
    t1_rep = t1.repeat(4, 1).t()

    t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
    q2 = torch.stack([rmat_t[:, 0, 1] - rmat_t[:, 1, 0],
                      rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
                      rmat_t[:, 1, 2] + rmat_t[:, 2, 1], t2], -1)
    t2_rep = t2.repeat(4, 1).t()

    t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
    q3 = torch.stack([t3, rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
                      rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
                      rmat_t[:, 0, 1] - rmat_t[:, 1, 0]], -1)
    t3_rep = t3.repeat(4, 1).t()

    # The four case masks are mutually exclusive; combine them with explicit
    # boolean operators rather than multiplication.
    mask_c0 = mask_d2 & mask_d0_d1
    mask_c1 = mask_d2 & ~mask_d0_d1
    mask_c2 = ~mask_d2 & mask_d0_nd1
    mask_c3 = ~mask_d2 & ~mask_d0_nd1
    mask_c0 = mask_c0.view(-1, 1).type_as(q0)
    mask_c1 = mask_c1.view(-1, 1).type_as(q1)
    mask_c2 = mask_c2.view(-1, 1).type_as(q2)
    mask_c3 = mask_c3.view(-1, 1).type_as(q3)

    q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3
    q /= torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 +  # noqa
                    t2_rep * mask_c2 + t3_rep * mask_c3)  # noqa
    q *= 0.5
    return q
def quaternion_to_angle_axis(quaternion: torch.Tensor) -> torch.Tensor:
    """Convert quaternions to axis-angle (Rodrigues) vectors.

    Borrowed from https://github.com/kornia/kornia, which adapted it from the
    Ceres solver (ceres-solver/include/ceres/rotation.h).

    Args:
        quaternion: tensor of shape ``(*, 4)`` with the scalar part at
            index 0 and the vector part at indices 1..3.

    Returns:
        Tensor of shape ``(*, 3)``: the rotation axis scaled by the angle.

    Raises:
        TypeError: if ``quaternion`` is not a ``torch.Tensor``.
        ValueError: if the last dimension is not 4.

    Example:
        >>> quaternion = torch.rand(2, 4)  # Nx4
        >>> angle_axis = quaternion_to_angle_axis(quaternion)  # Nx3
    """
    if not torch.is_tensor(quaternion):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(quaternion)))
    if quaternion.shape[-1] != 4:
        raise ValueError("Input must be a tensor of shape Nx4 or 4. Got {}"
                         .format(quaternion.shape))

    w = quaternion[..., 0]
    x = quaternion[..., 1]
    y = quaternion[..., 2]
    z = quaternion[..., 3]

    vec_norm_sq = x * x + y * y + z * z
    vec_norm = torch.sqrt(vec_norm_sq)

    # q and -q describe the same rotation; negating both atan2 arguments
    # when w < 0 keeps the resulting angle within (-pi, pi].
    angle = 2.0 * torch.where(w < 0.0,
                              torch.atan2(-vec_norm, -w),
                              torch.atan2(vec_norm, w))

    # With a zero vector part the angle/norm ratio is 0/0, so the constant
    # factor 2 is selected instead.
    scale = torch.where(vec_norm_sq > 0.0,
                        angle / vec_norm,
                        2.0 * torch.ones_like(vec_norm))

    return torch.stack([x * scale, y * scale, z * scale], dim=-1)
def rotation_matrix_to_angle_axis(rotation_matrix):
    """Convert rotation matrices to axis-angle (Rodrigues) vectors.

    ``(N, 3, 3)`` input is first padded to ``(N, 3, 4)`` with a ``[0, 0, 1]``
    column, because ``rotation_matrix_to_quaternion`` expects 3x4 matrices.
    ``(N, 3, 4)`` input is passed through as is.

    Args:
        rotation_matrix (Tensor): rotation matrices, ``(N, 3, 3)`` or ``(N, 3, 4)``.

    Returns:
        Tensor: axis-angle vectors of shape ``(N, 3)``. Any NaN entries are
        replaced with 0.

    Example:
        >>> input = torch.rand(2, 3, 4)  # Nx3x4
        >>> output = rotation_matrix_to_angle_axis(input)  # Nx3
    """
    if rotation_matrix.shape[1:] == (3, 3):
        rot_mat = rotation_matrix.reshape(-1, 3, 3)
        # Create the padding column directly with the input's dtype and
        # device, so the concat never mixes dtypes or needs a host copy.
        hom_mat = torch.tensor([0, 0, 1], dtype=rot_mat.dtype,
                               device=rot_mat.device)
        hom_mat = hom_mat.view(1, 3, 1).expand(rot_mat.shape[0], 3, 1)
        rotation_matrix = torch.cat([rot_mat, hom_mat], dim=-1)

    quaternion = rotation_matrix_to_quaternion(rotation_matrix)
    aa = quaternion_to_angle_axis(quaternion)
    # Numerical failures (e.g. sqrt of a negative divisor) surface as NaN;
    # they are zeroed rather than propagated.
    aa[torch.isnan(aa)] = 0.0
    return aa
# --- SMPL-X model (built by FrankMocap's BodyMocap) and neutral hand poses ---
default_checkpoint_body_smplx = './extra_data/body_module/pretrained_weights/smplx-03-28-46060-w_spin_mlc3d_46582-2089_2020_03_28-21_56_16.pt'
smpl_dir = './extra_data/smpl/'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

body_mocap = BodyMocap(default_checkpoint_body_smplx, smpl_dir,
                       device=device, use_smplx=True)
smplx_model = body_mocap.smpl

# All-zero hand poses with 12 values per hand.
# NOTE(review): 12 presumably matches a PCA hand-pose space; confirm against
# how FrankMocap configures its SMPL-X model.
left_hand_pose = torch.zeros(1, 12, dtype=torch.float32, device=device)
right_hand_pose = torch.zeros(1, 12, dtype=torch.float32, device=device)
# --- Load one PARE result and convert its rotation matrices to axis-angle ---
source_path = '/PARE/logs/demo/test_/pare_results/000001.pkl'
# NOTE(review): joblib.load unpickles the file; only use it on trusted data.
data = joblib.load(source_path)

print('pare_result')
for key, value in data.items():
    # Not every entry is guaranteed to be an array; fall back to its type.
    print(key, getattr(value, 'shape', type(value)))
print()

betas = torch.tensor(data['pred_shape']).to(device)
pred_pose = torch.tensor(data['pred_pose']).to(device)

# The first rotation is used as the root (global) orientation and the rest
# as body joint rotations. Both tensors are already on `device`.
body_pose = pred_pose[:, 1:].contiguous()
global_orient = pred_pose[:, 0].unsqueeze(1).contiguous()

# The hand poses and the mesh export below handle a single person. A larger
# batch would be silently flattened into one wrongly sized pose vector, so
# fail early with a clear message instead.
batch_size = 1
if pred_pose.shape[0] != batch_size:
    raise ValueError(
        'Expected a single person in {}, got pred_pose with shape {}'.format(
            source_path, tuple(pred_pose.shape)))

body_pose2 = rotation_matrix_to_angle_axis(body_pose.reshape(-1, 3, 3)).reshape(batch_size, -1)
global_orient2 = rotation_matrix_to_angle_axis(global_orient.reshape(-1, 3, 3)).reshape(batch_size, -1)
print('saved2', body_pose2.shape)
print('global_orient2', global_orient2.shape)
# --- SMPL-X forward pass with the converted axis-angle parameters ---
# Poses are axis-angle, hence pose2rot=True. Inference only, so no autograd
# graph is needed. (The SPIN-style SMPL wrapper class is not required here,
# so no separate SMPL model is loaded.)
with torch.no_grad():
    smplx_output = smplx_model(
        betas=betas,
        body_pose=body_pose2,
        global_orient=global_orient2,
        right_hand_pose=right_hand_pose,
        left_hand_pose=left_hand_pose,
        pose2rot=True,
    )
# --- Save the SMPL-X parameters (pickle) and the mesh (OBJ) ---
sav_dir = os.path.join(os.getcwd(), 'output/temp')
# The output directory may not exist yet; open() would fail without this.
os.makedirs(sav_dir, exist_ok=True)

base_name = os.path.basename(source_path)
pkl_path = osp.join(sav_dir, base_name)
obj_path = osp.join(sav_dir, os.path.splitext(base_name)[0] + '.obj')

# Compute everything before opening the output file, so a failure here does
# not leave a truncated pickle behind.
pred_vertices = smplx_output.vertices.detach().cpu().numpy()
body_pose = smplx_output.body_pose.detach().cpu().numpy()
print('saved body_pose', body_pose.shape)

result = {'body_pose': body_pose,
          'right_hand_pose': right_hand_pose.cpu().detach().numpy(),
          'left_hand_pose': left_hand_pose.cpu().detach().numpy(),
          # Axis-angle, the same representation as body_pose. The rotation
          # matrix form was saved here before, which mixed representations.
          'global_orient': global_orient2.detach().cpu().numpy().astype(np.float32),
          'v': pred_vertices}
with open(pkl_path, 'wb') as f:
    pickle.dump(result, f)

mesh = trimesh.Trimesh(pred_vertices[0], smplx_model.faces)
mesh.export(obj_path)
print('saved in ', pkl_path)
print('saved in ', obj_path)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment