Skip to content

Instantly share code, notes, and snippets.

@mjhong0708
Last active December 12, 2023 10:07
Show Gist options
  • Save mjhong0708/00f72e64155c6480a6e0e3c9d3e57e18 to your computer and use it in GitHub Desktop.
Save mjhong0708/00f72e64155c6480a6e0e3c9d3e57e18 to your computer and use it in GitHub Desktop.
RMSD by Kabsch algorithm
from typing import Tuple
import torch
# Short alias for torch.Tensor, used in the type annotations of this module.
Tensor = torch.Tensor
def find_alignment_kabsch(P: Tensor, Q: Tensor) -> Tuple[Tensor, Tensor]:
    """Find the rigid alignment of P onto Q using the Kabsch algorithm.

    The returned rotation ``R`` and translation ``t`` are meant to be applied
    as ``(R @ P.T).T + t``, which moves ``P`` as close as possible (in the
    least-squares sense) to ``Q``.

    Args:
        P (torch.Tensor): A tensor of shape (N, D) representing the first set
            of points (the set that gets moved). D is typically 3.
        Q (torch.Tensor): A tensor of shape (N, D) representing the second set
            of points (the target). Row ``i`` of ``P`` is paired with row
            ``i`` of ``Q``.

    Returns:
        Tuple[Tensor, Tensor]: ``(R, t)`` where ``R`` is the (D, D) proper
        rotation matrix (determinant +1) and ``t`` is the (D,) translation
        vector.
    """
    n_dims = P.shape[-1]

    # Shift both point sets so that their centroids sit at the origin.
    centroid_P, centroid_Q = P.mean(dim=0), Q.mean(dim=0)
    P_c, Q_c = P - centroid_P, Q - centroid_Q

    # Cross-covariance matrix and its SVD; the singular values are not needed.
    H = P_c.T @ Q_c
    U, _, Vt = torch.linalg.svd(H)
    V = Vt.T

    # d is -1 when V @ U.T contains a reflection; flipping the last axis
    # then turns the result into a proper (right-handed) rotation.
    d = torch.sign(torch.linalg.det(V @ U.T))

    # Build diag(1, ..., 1, d) by concatenation instead of in-place
    # assignment so the function stays usable under torch.vmap.
    diag_values = torch.cat(
        [
            torch.ones(n_dims - 1, dtype=P.dtype, device=P.device),
            d * torch.ones(1, dtype=P.dtype, device=P.device),
        ]
    )
    M = torch.eye(n_dims, dtype=P.dtype, device=P.device) * diag_values
    R = V @ M @ U.T

    # Translation that maps the rotated centroid of P onto the centroid of Q.
    t = centroid_Q - (R @ centroid_P.unsqueeze(-1)).squeeze(-1)
    return R, t
def calculate_rmsd(pos: Tensor, ref: Tensor) -> Tensor:
    """Calculate the root mean square deviation (RMSD) between pos and ref.

    ``ref`` is first rigidly aligned onto ``pos`` with the Kabsch algorithm
    (``find_alignment_kabsch``); the RMSD is then
    ``sqrt(mean_i(|aligned_ref_i - pos_i|^2))``.

    Args:
        pos (torch.Tensor): A tensor of shape (N, 3) representing the positions
            of the first set of points.
        ref (torch.Tensor): A tensor of shape (N, 3) representing the positions
            of the second set of points. Row ``i`` is paired with row ``i`` of
            ``pos``.

    Returns:
        torch.Tensor: Scalar tensor holding the RMSD after optimal alignment.

    Raises:
        ValueError: If ``pos`` and ``ref`` contain a different number of points.
    """
    if pos.shape[0] != ref.shape[0]:
        raise ValueError("pos and ref must have the same number of points")

    # Move ref onto pos with the optimal rotation and translation.
    R, t = find_alignment_kabsch(ref, pos)
    ref_aligned = (R @ ref.T).T + t

    # Root of the MEAN SQUARED per-point distance. Averaging the plain
    # per-point distances would give the mean deviation, not the RMSD.
    # NOTE: the gradient of sqrt is unbounded when the RMSD is exactly zero.
    squared_dist = (ref_aligned - pos).pow(2).sum(dim=1)
    return torch.sqrt(squared_dist.mean())
# vmap requires pytorch >= 2.0
def calculate_rmsd_matrix(R: Tensor) -> Tensor:
    """Compute all pairwise RMSD values within a stack of point sets.

    Args:
        R (torch.Tensor): Stack of point sets batched along dimension 0;
            presumably of shape (B, N, 3), since each slice is passed to
            ``calculate_rmsd``.

    Returns:
        torch.Tensor: Tensor whose entry ``[j, i]`` is
        ``calculate_rmsd(R[i], R[j])``, i.e. of shape (B, B).
    """
    # Inner map: sweep over every `pos` while `ref` is held fixed.
    rmsd_over_pos = torch.vmap(calculate_rmsd, in_dims=(0, None))
    # Outer map: repeat the sweep above for every `ref`.
    rmsd_over_pairs = torch.vmap(rmsd_over_pos, in_dims=(None, 0))
    pairwise = rmsd_over_pairs(R, R)
    return pairwise
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment