Skip to content

Instantly share code, notes, and snippets.

@sherwoac
Last active January 19, 2021 14:25
Show Gist options
  • Save sherwoac/a039cacfb64d29b005a3ff5be8c426c6 to your computer and use it in GitHub Desktop.
Save sherwoac/a039cacfb64d29b005a3ff5be8c426c6 to your computer and use it in GitHub Desktop.
learnable_rotation_matrix in tensorflow
# Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks" (CVPR 2019), https://arxiv.org/abs/1812.07035
import tensorflow as tf
def rotation_matrix_to_label_6d_flat(rotation_matrices):
    """
    Take a batch of 3x3 rotation matrices and return a flat batch x 6 representation for a loss.

    The 6D representation is the first column followed by the second column
    (Zhou et al., arXiv:1812.07035). The third column is not stored: for a
    rotation matrix it equals the cross product of the first two.

    :param rotation_matrices: batch x (3x3) rotation matrices. Any number of
        leading batch dims is accepted (... x 3 x 3).
    :return: ... x 6 tensor laid out as
        [m[..., 0, 0], m[..., 1, 0], m[..., 2, 0], m[..., 0, 1], m[..., 1, 1], m[..., 2, 1]]
    """
    # `...` instead of `:, :` so extra leading batch dims are supported; for a
    # rank-3 input this selects exactly the same columns as before.
    first_column = rotation_matrices[..., 0]
    second_column = rotation_matrices[..., 1]
    return tf.concat([first_column, second_column], axis=-1)
def label_6d_flat_to_rotation_matrix(rotation_6d_flat):
    """
    Convert the flat 6D rotation representation back into a 3x3 rotation matrix.

    This is the inverse of rotation_matrix_to_label_6d_flat. It uses the
    Gram-Schmidt construction of Zhou et al. (arXiv:1812.07035):
      - the first three values, normalised, give column b1;
      - the last three values, orthogonalised against b1 and normalised, give column b2;
      - b3 = b1 x b2.

    Normalisation uses tf.math.l2_normalize, which divides by
    sqrt(max(sum(x**2), epsilon)). This differs from x / ||x|| in two degenerate
    cases: when the first three values are all zero, or when the last three are
    parallel to the first. In those cases the result contains a zero column
    (not a valid rotation) instead of NaN. Its gradients stay finite instead of
    becoming NaN.

    :param rotation_6d_flat: batch x 6 flat representation. Any number of
        leading batch dims is accepted (... x 6).
    :return: ... x 3 x 3 rotation matrices whose columns are b1, b2, b3.
    """
    a1 = rotation_6d_flat[..., :3]
    a2 = rotation_6d_flat[..., 3:]

    b1 = tf.math.l2_normalize(a1, axis=-1)
    # Gram-Schmidt step: remove the component of a2 along b1, then normalise.
    a2_orthogonal = a2 - tf.reduce_sum(b1 * a2, axis=-1, keepdims=True) * b1
    b2 = tf.math.l2_normalize(a2_orthogonal, axis=-1)
    b3 = tf.linalg.cross(b1, b2)

    # Stacking on the last axis places b1, b2, b3 as matrix columns, i.e.
    # result[..., :, 0] == b1. This matches rotation_matrix_to_label_6d_flat,
    # which reads columns 0 and 1.
    return tf.stack([b1, b2, b3], axis=-1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment