Skip to content

Instantly share code, notes, and snippets.

@mkocabas
Created September 14, 2020 08:33
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mkocabas/6e17e91222cc1492b9886a00cba6e9f1 to your computer and use it in GitHub Desktop.
Save mkocabas/6e17e91222cc1492b9886a00cba6e9f1 to your computer and use it in GitHub Desktop.
####### LOSS FUNCTION #######
class MultivariateGaussianNegativeLogLikelihood(nn.Module):
    """Negative log-likelihood of targets under a diagonal Gaussian.

    The prediction is a mean ``pred_mean`` plus a per-dimension *log standard
    deviation* ``pred_var`` (the code uses sigma = exp(pred_var)). For one
    sample with D dimensions the log-likelihood is

        -0.5 * sum_i ((x_i - mu_i) / sigma_i)^2
        - sum_i log sigma_i
        - 0.5 * D * log(2 * pi)

    and the module returns the negated value averaged over the batch.
    """

    def __init__(self):
        super().__init__()

    def forward(self, pred_mean, pred_var, gt):
        """Return the batch-mean Gaussian NLL of ``gt``.

        Args:
            pred_mean: predicted means. Assumed shape (batch, n_dims); the
                per-sample terms are summed over dim 1 (TODO confirm callers
                only pass 2-D tensors).
            pred_var: predicted log standard deviations, same shape as
                ``pred_mean`` (despite the name, it is used as log-sigma).
            gt: target values, same shape as ``pred_mean``.

        Returns:
            Scalar tensor: mean over the remaining dims of the per-sample
            negative log-likelihood.
        """
        import logging

        mu = pred_mean
        logsigma = pred_var
        # Dimensionality of the Gaussian = number of components summed below.
        n_dims = gt.shape[1]

        # Quadratic (Mahalanobis) term for a diagonal covariance.
        quad_term = -0.5 * torch.sum(
            torch.square((gt - mu) / torch.exp(logsigma)), dim=1
        )
        # -0.5 * log det(Sigma) = -sum_i log(sigma_i) when Sigma = diag(sigma^2).
        log_det_term = -torch.sum(logsigma, dim=1)
        # Normalisation constant; a plain Python float.
        log2pi = -0.5 * n_dims * float(np.log(2 * np.pi))

        # Only pay for the .item() device syncs when debug logging is active.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'\nMSE: {quad_term.mean().item():.2f}'
                         f' Sigma: {log_det_term.mean().item():.2f}'
                         f' log2pi:{log2pi:.2f}')

        log_likelihood = quad_term + log_det_term + log2pi
        return torch.mean(-log_likelihood)
####### HMR HEAD #######
class hmr_head(nn.Module):
    """Iterative-regression head that also predicts pose/shape variance outputs.

    Starting from mean parameters loaded from ``smpl_mean_params``, the head
    repeatedly concatenates the pooled image feature with the current
    pose (6D, 24 joints), shape (10) and camera (3) estimates. It then adds
    the predicted residuals to those estimates.

    The pose and shape decoders emit twice as many values as needed. The first
    half is the residual, and the second half (from the final iteration) is
    returned as ``pred_pose_6d_var`` / ``pred_shape_var``. These are presumably
    consumed as log standard deviations by the Gaussian NLL loss — TODO
    confirm against the training code.
    """

    def __init__(
        self,
        num_input_features,
        smpl_mean_params=SMPL_MEAN_PARAMS,
    ):
        super().__init__()
        npose = 24 * 6  # 24 joints x 6D rotation representation
        nshape = 10     # shape coefficients (matches decshape split below)
        ncam = 3        # camera parameters (matches deccam output)
        self.npose = npose

        # NOTE(review): a fixed 7x7 window only yields a 1x1 map (and hence a
        # num_input_features-sized vector for fc1) when the input feature map
        # is 7x7 — confirm the backbone output size.
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc1 = nn.Linear(num_input_features + npose + nshape + ncam, 1024)
        self.drop1 = nn.Dropout()
        self.fc2 = nn.Linear(1024, 1024)
        self.drop2 = nn.Dropout()
        # Double the MLP output for pose and shape: [residual | variance].
        self.decpose = nn.Linear(1024, npose * 2)
        self.decshape = nn.Linear(1024, nshape * 2)
        self.deccam = nn.Linear(1024, ncam)

        # Mean parameters; keys 'pose', 'shape', 'cam' are read from the file.
        mean_params = np.load(smpl_mean_params)
        init_pose = torch.from_numpy(mean_params['pose'][:]).unsqueeze(0)
        init_shape = torch.from_numpy(mean_params['shape'][:].astype('float32')).unsqueeze(0)
        init_cam = torch.from_numpy(mean_params['cam']).unsqueeze(0)
        # Buffers follow the module across .to()/.cuda() and into state_dict.
        self.register_buffer('init_pose', init_pose)
        self.register_buffer('init_shape', init_shape)
        self.register_buffer('init_cam', init_cam)

    def forward(
        self,
        features,
        init_pose=None,
        init_shape=None,
        init_cam=None,
        n_iter=3
    ):
        """Regress pose/shape/camera (and variance outputs) from ``features``.

        Args:
            features: backbone feature map, pooled by ``self.avgpool`` and
                flattened per sample; presumably (B, num_input_features, 7, 7).
            init_pose: optional (B, 144) starting 6D pose; defaults to the
                stored mean pose.
            init_shape: optional (B, 10) starting shape; defaults to the mean.
            init_cam: optional (B, 3) starting camera; defaults to the mean.
            n_iter: number of refinement iterations (must be >= 1).

        Returns:
            dict with 'pred_pose' (B, 24, 3, 3) rotation matrices, 'pred_cam',
            'pred_shape', 'pred_pose_6d', and the variance halves
            'pred_pose_6d_var' / 'pred_shape_var' from the last iteration.

        Raises:
            ValueError: if ``n_iter`` < 1, since no prediction would be made.
        """
        if n_iter < 1:
            raise ValueError(f'n_iter must be >= 1, got {n_iter}')

        batch_size = features.shape[0]
        if init_pose is None:
            init_pose = self.init_pose.expand(batch_size, -1)
        if init_shape is None:
            init_shape = self.init_shape.expand(batch_size, -1)
        if init_cam is None:
            init_cam = self.init_cam.expand(batch_size, -1)

        xf = self.avgpool(features)
        xf = xf.view(xf.size(0), -1)

        # Split point between residual and variance halves of decshape.
        nshape = self.decshape.out_features // 2

        pred_pose, pred_shape, pred_cam = init_pose, init_shape, init_cam
        for _ in range(n_iter):
            xc = torch.cat([xf, pred_pose, pred_shape, pred_cam], 1)
            xc = self.drop1(self.fc1(xc))
            xc = self.drop2(self.fc2(xc))
            # Evaluate each decoder once per iteration and split its output,
            # instead of running the same Linear layer twice.
            pose_out = self.decpose(xc)
            shape_out = self.decshape(xc)
            pred_pose = pose_out[:, :self.npose] + pred_pose
            pred_shape = shape_out[:, :nshape] + pred_shape
            pred_cam = self.deccam(xc) + pred_cam

        # Variance outputs come from the final iteration's decoder outputs.
        pred_pose_var = pose_out[:, self.npose:]
        pred_shape_var = shape_out[:, nshape:]

        pred_rotmat = rot6d_to_rotmat(pred_pose).view(batch_size, 24, 3, 3)

        output = {
            'pred_pose': pred_rotmat,
            'pred_cam': pred_cam,
            'pred_shape': pred_shape,
            'pred_pose_6d': pred_pose,
            'pred_pose_6d_var': pred_pose_var,
            'pred_shape_var': pred_shape_var,
        }
        return output
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment