Skip to content

Instantly share code, notes, and snippets.

@isarandi
Last active June 12, 2023 12:23
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save isarandi/95918dcf02c2ed5cf3db50613e5aaee7 to your computer and use it in GitHub Desktop.
Save isarandi/95918dcf02c2ed5cf3db50613e5aaee7 to your computer and use it in GitHub Desktop.
Procrustes transformation, implemented in TensorFlow. Procrustes analysis takes two sets of corresponding points and computes the rigid (or similarity) transformation that best aligns them in the least-squares sense.
# Copyright 2021 Istvan Sarandi
# MIT License
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software
# and associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
def procrustes(X, Y, validity_mask, allow_scaling=False, allow_reflection=False):
    """Register the points in Y to be closest to the corresponding points in X.

    Estimates, in a least-squares sense, the rotation and translation (plus,
    optionally, uniform scaling and reflection) that best map Y onto X, and
    returns Y with that transform applied.

    This function operates on batches. For each item in the batch a separate
    transform is computed independently of the others.

    Arguments:
        X: floating-point Tensor with shape
            [batch_size, n_points, point_dimensionality] holding the target
            points. Any number of leading batch dimensions is accepted.
        Y: Tensor with the same shape and dtype as X holding the points to be
            registered onto X.
        validity_mask: Boolean Tensor with shape [batch_size, n_points]
            indicating which point correspondences to use. Points where the
            mask is False are ignored in BOTH X and Y when estimating the
            transform.
        allow_scaling: boolean, specifying whether uniform scaling is allowed
        allow_reflection: boolean, specifying whether reflections are allowed

    Returns:
        The transformed version of Y, with the same shape as Y. The estimated
        transform is applied to every point of Y, including masked-out ones.

    Note:
        An example with no valid points, or whose valid points all coincide
        (in X or in Y), has zero spread; the normalization then divides by
        zero and the output for that example is NaN.
    """
    # Imported locally so the function is self-contained; tf.newaxis is used
    # instead of np.newaxis so NumPy is not required.
    import tensorflow as tf

    # [..., n_points, 1]: broadcasts over the coordinate axis in tf.where.
    mask = validity_mask[..., tf.newaxis]
    # Count valid points in the points' own dtype so float64 inputs also work.
    n_valid = tf.math.count_nonzero(mask, axis=-2, keepdims=True, dtype=X.dtype)

    def normalize(Z):
        """Return (masked mean, Frobenius norm, centered-and-normalized Z)."""
        zeros = tf.zeros_like(Z)
        Z = tf.where(mask, Z, zeros)
        # Masked mean: sum over valid points divided by their count. This does
        # not need the static n_points, which may be None under tf.function.
        mean = tf.reduce_sum(Z, axis=-2, keepdims=True) / n_valid
        centered = tf.where(mask, Z - mean, zeros)
        norm = tf.norm(centered, ord='fro', axis=(-2, -1), keepdims=True)
        return mean, norm, centered / norm

    meanX, normX, normalizedX = normalize(X)
    meanY, normY, normalizedY = normalize(Y)

    # Cross-covariance of the normalized sets. With A = U diag(s) V^T, the
    # orthogonal T = U V^T minimizes ||normalizedY @ T - normalizedX||_F.
    A = tf.linalg.matrix_transpose(normalizedY) @ normalizedX
    s, U, V = tf.linalg.svd(A, full_matrices=False)
    T = U @ tf.linalg.matrix_transpose(V)
    # [..., dim, 1] so reductions over axis -2 keep a broadcastable rank.
    s = s[..., tf.newaxis]

    if allow_scaling:
        relative_scale = normX / normY
        output_scale = relative_scale * tf.reduce_sum(s, axis=-2, keepdims=True)
    else:
        output_scale = 1

    if not allow_reflection:
        # det(T) < 0 means T contains a reflection. Remove it by flipping the
        # last singular vector pair, i.e. the one with the smallest singular
        # value (tf.linalg.svd returns singular values in descending order).
        has_reflection = (tf.linalg.det(T) < 0)[..., tf.newaxis, tf.newaxis]
        T_mirror = T - 2 * tf.einsum('...i,...k->...ik', U[..., -1], V[..., -1])
        T = tf.where(has_reflection, T_mirror, T)
        if allow_scaling:
            # Flipping that direction reduces the trace term sum(s) by 2 * s_last.
            output_scale_mirror = output_scale - 2 * relative_scale * s[..., -1:, :]
            output_scale = tf.where(has_reflection, output_scale_mirror, output_scale)

    return ((Y - meanY) @ T) * output_scale + meanX
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment