Skip to content

Instantly share code, notes, and snippets.

@aluo-x
Created July 24, 2020 16:45
Show Gist options
  • Save aluo-x/d2c27c07bdcb6af2895a8c83e2b63bae to your computer and use it in GitHub Desktop.
Save aluo-x/d2c27c07bdcb6af2895a8c83e2b63bae to your computer and use it in GitHub Desktop.
elif self.DYN_ARCH in ['single-fc-leaky-wide-discrete']:
# Dynamics branch for DYN_ARCH == 'single-fc-leaky-wide-discrete'.
# Predicts per-step object/agent state deltas. The object rotation is built from
# `self.config.rot_heads` discrete quaternion hypotheses ("heads"): each is a
# fixed base quaternion (self.config.quat) refined by a learned unit residual,
# and a softmax over `rotation_out` blends them.
# Returns (objs_delta, agents_delta, residual, new_rot, per-head softmax weights).
# NOTE(review): indentation was lost in this paste; this is a fragment of a
# larger if/elif chain inside a method whose header is not shown.
objs_state, agents_state = state
# NOTE(review): object_class is assigned but never used in this branch.
object_class = extra
agent_feat = tf.reshape(agents_state, [B, N_AGENT_JOINTS * AGENT_DIM])
# Only a single object is supported here: use the first object slot only.
obj_feat = objs_state[:, 0]
net_input = tf.concat([obj_feat, agent_feat, action], axis=-1)
backbone_out = self.fcnet(net_input, is_train)
net_output = self.offsetnet(backbone_out, is_train)
# Non-rotation part of the object delta (up to 9 columns). The 4-component
# rotation quaternion is spliced in after the first 3 columns further below.
obj_delta_partial = net_output[:, :9]
agent_delta = self.agentnet(backbone_out, is_train)
# Per-head logits; presumably (B, rot_heads) given the [:, :, None] broadcast
# below -- TODO confirm.
rotation_out = self.rotationet(backbone_out, is_train)
# NOTE(review): this softmax uses temperature 1, while S2_weighted_mean below
# uses rotation_out*10.0 -- confirm the mismatch is intended.
residual_in = tf.concat([tf.nn.softmax(rotation_out, axis=1), backbone_out], axis=-1)
# One residual quaternion per head, reshaped to (batch, rot_heads, 4) and
# unit-normalized along the last axis (tf.linalg.normalize returns
# (normalized, norm); [0] keeps the normalized tensor).
# NOTE(review): backbone_out.shape[0] is the static batch size; this and the
# np.tile calls below presumably fail if the batch dimension is unknown (None).
residual = tf.linalg.normalize(tf.reshape(self.residualnet(residual_in, is_train), (backbone_out.shape[0], self.config.rot_heads, 4)), axis=2)[0]
# Candidate rotation per head = base quaternion (self.config.quat, tiled over the
# batch) composed with the learned residual, then renormalized.
# config.quat is presumably (rot_heads, 4) -- TODO confirm.
new_rot = tf.linalg.normalize(tfq.quaternion_multiply(tf.tile(tf.convert_to_tensor(self.config.quat, dtype="float32")[None], (backbone_out.shape[0], 1, 1)), residual), axis=2)[0]
# Reference direction: the unit x-axis, shape (1, 3).
orig = np.array([[1.0, 0.0, 0.0]])
# orig[None] has shape (1, 1, 3); it is tiled to (batch, rot_heads, 3) so each
# head's quaternion can rotate its own copy of the x-axis.
orig2 = tf.convert_to_tensor(np.tile(orig[None], [backbone_out.shape[0], self.config.rot_heads, 1]), dtype="float32")
# Each head's rotated x-axis, normalized onto the unit sphere S^2.
S2_candidate = tf.linalg.normalize(tfq.rotate_vector_by_quaternion(tfq.Quaternion(new_rot), orig2), axis=2)[0]
# Blend the per-head directions: (B, K, 3) * softmax weights broadcast as (B, K, 1),
# averaged over heads, then renormalized to a single unit direction per sample.
# The factor 10.0 sharpens the softmax toward the highest-scoring head.
# NOTE(review): if the weighted candidates cancel out, the mean is ~0 and the
# normalize would likely produce NaN.
S2_weighted_mean = tf.linalg.normalize(tf.reduce_mean(S2_candidate*tf.nn.softmax(rotation_out*10.0, axis=1)[:, :, None], axis=1), axis=1)[0]
orig3 = tf.convert_to_tensor(np.tile(orig, [backbone_out.shape[0], 1]), dtype="float32")
# Shortest-arc quaternion rotating the x-axis onto S2_weighted_mean:
# q = normalize([1 + a.b, a x b]), with the scalar part first. Presumably this
# matches the (w, x, y, z) layout that tfq expects -- confirm.
# NOTE(review): degenerate when the two vectors are antiparallel (a.b = -1),
# where q is the zero vector and the normalize would likely produce NaN.
crossprod = tf.linalg.cross(orig3, S2_weighted_mean)
dotprod = tf.einsum('ij,ij->i', orig3, S2_weighted_mean)
qw = 1.0 + dotprod[:, None]
# stop_gradient: no gradient flows back through this predicted rotation.
converted_quat = tf.stop_gradient(tf.linalg.normalize(tf.concat([qw, crossprod], axis=1), axis=1)[0])
# Layout: first 3 columns of obj_delta_partial, then the 4-component quaternion,
# then the remaining columns of obj_delta_partial.
obj_delta = tf.concat([obj_delta_partial[:, :3], converted_quat, obj_delta_partial[:, 3:]], axis=1)
# A second object slot is emitted as all zeros (only one object is predicted).
objs_delta = tf.stack([obj_delta, tf.zeros_like(obj_delta)], axis=1)
agents_delta = tf.expand_dims(agent_delta, axis=1)
# utils.basic.normalize_quat is defined elsewhere; presumably it renormalizes the
# quaternion components of each state -- verify.
objs_delta = utils.basic.normalize_quat(objs_delta)
agents_delta = utils.basic.normalize_quat(agents_delta)
# The last element is the temperature-1 per-head softmax, not the x10 one used above.
return objs_delta, agents_delta, residual, new_rot, tf.nn.softmax(rotation_out, axis=1)
###################################
####### Loss
# Residual regularizer and per-head rotation loss.
# NOTE(review): indentation was lost in this paste. Presumably L46-52 form the body
# of the first `if` and the lines after the second `if` form its body -- confirm.
#
# Residual regularizer: pulls each learned residual quaternion toward the
# identity (1, 0, 0, 0). `residual` is a list of rank-3 tensors, presumably
# (B, rot_heads, 4) each -- TODO confirm.
# NOTE(review): unlike the `rots` check below, this does not guard against
# residual being None; len(None) would raise TypeError.
if len(residual)>0:
orig = np.array([[1.0, 0.0, 0.0, 0.0]])
residual_for_loss = tf.concat(residual, axis=0)
# Identity quaternion tiled to match residual_for_loss's first two dimensions.
orig2 = tf.convert_to_tensor(np.tile(orig[None], [residual_for_loss.shape[0], residual_for_loss.shape[1], 1]),
dtype="float32")
# 1 - <q, identity>^2 is zero for q = +/-identity. Squaring makes the loss invariant
# to the quaternion sign ambiguity (q and -q). The loss is scaled by 0.05.
residual_dot = tf.reduce_mean(1.0 - tf.reduce_sum(orig2*residual_for_loss, axis=2)**2.0)*0.05
total_loss += residual_dot
vis["scalar-d_objs_residual_loss"] = residual_dot
# Rotation loss: compares every head's candidate quaternion with the ground-truth
# orientation delta of the first object, weighted by the per-head weights.
if (not (rots is None)) and (len(rots)>0):
gt_xyz_delta, gt_orn_delta, gt_vel_delta = utils.basic.split_states(gt_objs_delta_state, mode="h13")
stacked_rots = tf.stack(rots, axis=1)
# Stacked over time: now (B, T, K, 4), where K = rot_heads.
# gt_orn_delta is rank 4. [:, :, :1, :] keeps only the first object slot, which is
# then tiled across the K heads.
gt_orn_obj_delta = tf.tile(gt_orn_delta[:, :,:1, :], tf.constant([1,1,self.config.rot_heads, 1], tf.int32))
# Sign-invariant quaternion distance per head: 1 - <q_gt, q_head>^2.
rots_loss_unweighted = 1.0- tf.square(tf.reduce_sum(input_tensor=tf.multiply(gt_orn_obj_delta, stacked_rots), axis=-1))
# print(rot_weights[0].shape, len(rot_weights), tf.stack(rot_weights, axis=1).shape)
# rot_weights: presumably a list of per-step (B, K) head weights; stacked to (B, T, K).
rots_loss_weighted = rots_loss_unweighted*tf.stack(rot_weights, axis=1)
total_rot_loss = tf.reduce_mean(input_tensor=rots_loss_weighted)
# NOTE(review): despite "degree" in the logged key, this value is 1 - dot^2, not an
# angle in degrees.
vis["scalar-d_objs_weighted_degree_loss"] = total_rot_loss
total_loss+= total_rot_loss
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment