@bougui505
Last active April 9, 2023 07:17
Rigid alignment between points (Kabsch algorithm): PyTorch implementation
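The Kabsch algorithm named above has a closed form. In the notation of the function below (rows $a_i$ of $A$ and $b_i$ of $B$, $N$ corresponding points in $D$ dimensions), it computes the following quantities:

$$\bar a = \frac{1}{N}\sum_{i=1}^{N} a_i, \qquad \bar b = \frac{1}{N}\sum_{i=1}^{N} b_i$$

$$H = \sum_{i=1}^{N} (a_i - \bar a)(b_i - \bar b)^\top = U \Sigma V^\top$$

$$R = V U^\top, \qquad t = \bar b - R\,\bar a$$

This $R$ minimizes $\sum_i \lVert R a_i + t - b_i \rVert^2$ over all orthogonal matrices, reflections included. Restricting the result to proper rotations ($\det R = +1$) requires the usual sign correction $R = V\,\mathrm{diag}(1, \dots, 1, d)\,U^\top$ with $d = \operatorname{sign}\det(V U^\top)$, which the function does not apply.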
#!/usr/bin/env python3
# -*- coding: UTF8 -*-
# Author: Guillaume Bouvier -- guillaume.bouvier@pasteur.fr
# https://research.pasteur.fr/en/member/guillaume-bouvier/
# 2020-10-01 11:39:39 (UTC+0200)
import torch
import numpy as np


def find_rigid_alignment(A, B):
    """
    See: https://en.wikipedia.org/wiki/Kabsch_algorithm
    2-D or 3-D registration with known correspondences.
    The optimal transform is computed on the centered point clouds; the
    translation t then maps the result back into B's coordinate frame.
    No reflection correction is applied: R is the optimal orthogonal matrix
    and may be a reflection (det(R) = -1), as in the second test below.
    Args:
    - A: Torch tensor of shape (N,D) -- Point Cloud to Align (source)
    - B: Torch tensor of shape (N,D) -- Reference Point Cloud (target)
    Returns:
    - R: optimal rotation (or reflection), tensor of shape (D,D)
    - t: optimal translation, tensor of shape (D,)
    such that (R.mm(A.T)).T + t approximates B.
    Test on rotation + translation and on rotation + translation + reflection
    >>> A = torch.tensor([[1., 1.], [2., 2.], [1.5, 3.]], dtype=torch.float)
    >>> theta = np.pi / 3  # 60 degrees (numpy trig functions take radians)
    >>> R0 = torch.tensor([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]], dtype=torch.float)
    >>> B = (R0.mm(A.T)).T
    >>> t0 = torch.tensor([3., 3.])
    >>> B += t0
    >>> R, t = find_rigid_alignment(A, B)
    >>> A_aligned = (R.mm(A.T)).T + t
    >>> rmsd = torch.sqrt(((A_aligned - B)**2).sum(dim=1).mean())
    >>> bool(rmsd < 1e-5)
    True
    >>> B *= torch.tensor([-1., 1.])
    >>> R, t = find_rigid_alignment(A, B)
    >>> A_aligned = (R.mm(A.T)).T + t
    >>> rmsd = torch.sqrt(((A_aligned - B)**2).sum(dim=1).mean())
    >>> bool(rmsd < 1e-5)
    True
    """
    a_mean = A.mean(dim=0)
    b_mean = B.mean(dim=0)
    A_c = A - a_mean
    B_c = B - b_mean
    # Cross-covariance matrix H = A_c^T B_c, shape (D,D)
    H = A_c.T.mm(B_c)
    # SVD H = U diag(S) Vh with Vh = V^T (torch.svd is deprecated)
    U, S, Vh = torch.linalg.svd(H)
    # Rotation matrix R = V U^T (no reflection correction)
    R = Vh.T.mm(U.T)
    # Translation vector: maps the rotated source centroid onto the target centroid
    t = b_mean - R.mv(a_mean)
    return R, t


if __name__ == "__main__":
    import doctest
    doctest.testmod()
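The function above applies no reflection correction, so when the best orthogonal fit is a reflection (as in the second doctest) it returns R with det(R) = -1. If a proper rotation is required, the standard Kabsch sign correction can be added. What follows is a minimal sketch, assuming PyTorch 1.8 or later for `torch.linalg.svd`; the name `find_rigid_alignment_proper` is hypothetical and not part of the gist.

```python
import math

import torch


def find_rigid_alignment_proper(A: torch.Tensor, B: torch.Tensor):
    """Hypothetical variant of find_rigid_alignment (not part of the gist).

    Same inputs as the gist's function: A (source) and B (target) are (N, D)
    tensors of corresponding points. Returns R of shape (D, D) with
    det(R) = +1 and t of shape (D,) such that A @ R.T + t approximates B.
    """
    a_mean = A.mean(dim=0)
    b_mean = B.mean(dim=0)
    # Cross-covariance of the centered clouds, shape (D, D)
    H = (A - a_mean).T @ (B - b_mean)
    # torch.linalg.svd returns Vh = V^T, i.e. H = U @ diag(S) @ Vh
    U, S, Vh = torch.linalg.svd(H)
    V = Vh.T
    # d = -1 when V @ U.T would be a reflection, +1 otherwise
    d = torch.sign(torch.det(V @ U.T))
    # Singular values are sorted in descending order, so the last diagonal
    # entry flips the direction paired with the smallest singular value
    D = torch.eye(H.shape[0], dtype=H.dtype, device=H.device)
    D[-1, -1] = d
    R = V @ D @ U.T
    t = b_mean - R @ a_mean
    return R, t


if __name__ == "__main__":
    # Rotate the gist's 2-D test points by 60 degrees and translate them
    c, s = math.cos(math.pi / 3), math.sin(math.pi / 3)
    R0 = torch.tensor([[c, -s], [s, c]])
    A = torch.tensor([[1.0, 1.0], [2.0, 2.0], [1.5, 3.0]])
    B = A @ R0.T + torch.tensor([3.0, 3.0])
    R, t = find_rigid_alignment_proper(A, B)
    print(torch.allclose(R, R0, atol=1e-5))  # True: the rotation is recovered
    # Mirror the target: R stays a proper rotation, so the fit is no longer exact
    R, t = find_rigid_alignment_proper(A, B * torch.tensor([-1.0, 1.0]))
    print(torch.det(R).item())  # approximately +1.0
```

Compared with the gist's function, the only change is the diagonal matrix D: when the best orthogonal fit is already a rotation, d = +1 and both versions return the same R.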