@jlmelville
Last active January 6, 2022 20:32
The Kabsch algorithm in R for aligning one point set over another
#' Kabsch Algorithm
#'
#' Aligns two sets of points via rotations and translations.
#'
#' Given two sets of points, with one specified as the reference set,
#' the other set will be rotated so that the RMSD between the two is minimized.
#' The format of the matrix is that there should be one row for each of
#' n observations, and the number of columns, d, specifies the dimensionality
#' of the points. The point sets must be of equal size and with the same
#' ordering, i.e. point one of the second matrix is mapped to point one of
#' the reference matrix, point two of the second matrix is mapped to point two
#' of the reference matrix, and so on.
#'
#' @param pm n x d matrix of points to align to \code{qm}.
#' @param qm n x d matrix of reference points.
#' @return Matrix \code{pm} rotated and translated so that the ith point
#' is aligned to the ith point of \code{qm} in the least-squares sense.
#' @references
#' \url{https://en.wikipedia.org/wiki/Kabsch_algorithm}
kabsch <- function(pm, qm) {
  pm_dims <- dim(pm)
  if (!all(dim(qm) == pm_dims)) {
    stop(call. = TRUE, "Point sets must have the same dimensions")
  }
  # The diagonal correction matrix dm has (ncol - 1) leading ones on its diagonal
  diag_ones <- rep(1, pm_dims[2] - 1)
  # center the points
  pm <- scale(pm, center = TRUE, scale = FALSE)
  qm <- scale(qm, center = TRUE, scale = FALSE)
  am <- crossprod(pm, qm)
  svd_res <- svd(am)
  # use the sign of the determinant to exclude reflections, ensuring a
  # right-handed coordinate system
  d <- determinant(tcrossprod(svd_res$v, svd_res$u))$sign
  dm <- diag(c(diag_ones, d))
  # rotation matrix
  um <- svd_res$v %*% tcrossprod(dm, svd_res$u)
  # Rotate and then translate to the original centroid location of qm
  sweep(t(tcrossprod(um, pm)), 2, -attr(qm, "scaled:center"))
}
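A compact summary of what `kabsch()` computes may help when reading the code. The symbols are introduced here for illustration and do not appear in the gist: $P$ and $Q$ are the centered `pm` and `qm` with one point per row, $\bar{q}$ is the column-mean vector of the original `qm`, and $\mathbf{1}$ is a column of $n$ ones. The steps follow the Wikipedia description cited under `@references`.

```latex
\begin{aligned}
H &= P^{\top} Q = U \Sigma V^{\top} && \text{(\texttt{am}, \texttt{svd\_res})}\\
d &= \operatorname{sign}\,\det\!\left(V U^{\top}\right), \quad
D = \operatorname{diag}(1, \ldots, 1, d) && \text{(\texttt{d}, \texttt{dm})}\\
R &= V D U^{\top} && \text{(\texttt{um})}\\
P_{\text{aligned}} &= P R^{\top} + \mathbf{1}\,\bar{q}^{\top} && \text{(return value)}
\end{aligned}
```

Because the rows of $P$ are points, $P R^{\top}$ applies $R$ to every point; this is what `t(tcrossprod(um, pm))` evaluates before `sweep()` adds $\bar{q}$ back.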
@radt0005

radt0005 commented Apr 9, 2017

Thanks! This was very helpful for finding rotation matrices for pairs of points in 3D. I had a little trouble with line 33 causing

Error in svd_res$u %*% dm : non-conformable arguments

and I noticed that dm was 2x2 when, in my case, it needed to be 3x3. Here's an easy fix to put in place of line 30, with an added error trap to make sure pm and qm are the same size:

  if (all(dim(qm) == dim(pm))) {
    l <- dim(qm)[2] - 1
  } else {
    stop(call. = TRUE, "Point sets must have same dimensions.")
  }

  dm <- diag(c(rep(1, l), d))
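A small sketch shows why `dm` must be d x d and why it carries the determinant sign. It assumes a random 3D point set `pm` and its mirror image `qm` (both made up for illustration); the remaining steps follow the gist's code. For a mirror image the best orthogonal fit $V U^{\top}$ is a reflection, and the sign stored in `dm` turns it into a proper rotation.

```r
# Sketch: a 3D point set and its mirror image (made-up data for illustration)
set.seed(42)
pm <- scale(matrix(rnorm(30), ncol = 3), center = TRUE, scale = FALSE)
# Flipping the sign of z is a reflection, not a rotation
qm <- pm %*% diag(c(1, 1, -1))

svd_res <- svd(crossprod(pm, qm))
d <- determinant(tcrossprod(svd_res$v, svd_res$u))$sign

# (ncol - 1) ones followed by d: a 3 x 3 matrix for 3D points
dm <- diag(c(rep(1, ncol(pm) - 1), d))

um_naive <- tcrossprod(svd_res$v, svd_res$u)    # V U^T, no sign correction
um <- svd_res$v %*% tcrossprod(dm, svd_res$u)  # V D U^T, as in the gist

print(d)              # -1: the best orthogonal fit is a reflection
print(dim(dm))        # 3 3
print(det(um_naive))  # approximately -1
print(det(um))        # approximately 1: a proper rotation
```

With a 2 x 2 `dm` and 3 x 3 singular vectors, the product inside `tcrossprod(dm, svd_res$u)` would fail with the same non-conformable-arguments error quoted above.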

@jlmelville
Author

Glad this was helpful. I have incorporated your corrections, thank you!

I have also fixed two embarrassing errors:

  1. The right and left singular vectors should have been swapped when generating the rotation matrix.
  2. The centering operation over pm was completely wrong.
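Both fixes can be checked with a short sketch that applies a known 2D rotation (an arbitrary angle of pi/6, chosen for illustration) to random points and then recovers it from the SVD. With `crossprod(pm, qm)` as in the gist, $V U^{\top}$ reproduces the rotation that was applied, while the swapped product $U V^{\top}$ gives its inverse. No sign correction is needed here because the applied transform is a pure rotation.

```r
# Sketch: recover a known 2D rotation to see which singular vectors go where
theta <- pi / 6
# Column-major fill gives [cos -sin; sin cos], a counter-clockwise rotation
rot <- matrix(c(cos(theta), sin(theta), -sin(theta), cos(theta)), nrow = 2)

set.seed(1)
pm <- matrix(rnorm(20), ncol = 2)  # 10 points in 2D, one per row
qm <- pm %*% t(rot)                # rotate every point (row) by rot

# Center both sets; scale() keeps the subtracted means in "scaled:center"
pm_c <- scale(pm, center = TRUE, scale = FALSE)
qm_c <- scale(qm, center = TRUE, scale = FALSE)
print(colMeans(pm_c))              # both entries ~0 after centering

svd_res <- svd(crossprod(pm_c, qm_c))
# V U^T is the rotation applied to pm ...
print(all.equal(tcrossprod(svd_res$v, svd_res$u), rot))     # TRUE
# ... while the swapped U V^T is its inverse (the transpose)
print(all.equal(tcrossprod(svd_res$u, svd_res$v), t(rot)))  # TRUE
```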

Memo to self: in future, check that kabsch(pm, pm) == pm and similarly for qm before posting gist. D'oh.
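One way to act on that memo is a small self-check. In this sketch the body of `kabsch()` is copied from the gist so the snippet runs on its own; the random test data, the rotation angle and the translation `c(1, -2, 3)` are arbitrary choices for illustration. Exact `==` comparisons are too strict for floating-point results, so `all.equal()` is used instead. The last check confirms that a rotated and translated copy of `pm` is recovered.

```r
# Copied from the gist above so this snippet is self-contained
kabsch <- function(pm, qm) {
  pm_dims <- dim(pm)
  if (!all(dim(qm) == pm_dims)) {
    stop(call. = TRUE, "Point sets must have the same dimensions")
  }
  diag_ones <- rep(1, pm_dims[2] - 1)
  pm <- scale(pm, center = TRUE, scale = FALSE)
  qm <- scale(qm, center = TRUE, scale = FALSE)
  svd_res <- svd(crossprod(pm, qm))
  d <- determinant(tcrossprod(svd_res$v, svd_res$u))$sign
  dm <- diag(c(diag_ones, d))
  um <- svd_res$v %*% tcrossprod(dm, svd_res$u)
  sweep(t(tcrossprod(um, pm)), 2, -attr(qm, "scaled:center"))
}

set.seed(123)
pm <- matrix(rnorm(15), ncol = 3)  # 5 points in 3D
qm <- matrix(rnorm(15), ncol = 3)

# Aligning a set onto itself should leave it unchanged
print(isTRUE(all.equal(kabsch(pm, pm), pm)))  # TRUE
print(isTRUE(all.equal(kabsch(qm, qm), qm)))  # TRUE

# A copy of pm rotated about z and then translated should be matched exactly
theta <- pi / 3
rot_z <- matrix(c(cos(theta), sin(theta), 0,
                  -sin(theta), cos(theta), 0,
                  0, 0, 1), nrow = 3)
moved <- sweep(pm %*% t(rot_z), 2, c(1, -2, 3), "+")
print(isTRUE(all.equal(kabsch(pm, moved), moved)))  # TRUE
```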

@jlmelville
Author

jlmelville commented Dec 7, 2017

Minor change to use crossprod and tcrossprod rather than %*% and t() directly.
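For readers unfamiliar with these base R functions: `crossprod(x, y)` computes `t(x) %*% y` and `tcrossprod(x, y)` computes `x %*% t(y)`; R's documentation describes them as formally equivalent to those expressions but usually slightly faster. A quick sketch with arbitrary random matrices shows the equivalence, compared with `all.equal()` because the two routes may round differently:

```r
set.seed(7)
a  <- matrix(rnorm(12), nrow = 4)  # 4 x 3
b  <- matrix(rnorm(8), nrow = 4)   # 4 x 2
cm <- matrix(rnorm(6), nrow = 2)   # 2 x 3

# crossprod(a, b) is t(a) %*% b, a 3 x 2 matrix
print(all.equal(crossprod(a, b), t(a) %*% b))     # TRUE
# tcrossprod(a, cm) is a %*% t(cm), a 4 x 2 matrix
print(all.equal(tcrossprod(a, cm), a %*% t(cm)))  # TRUE
# With one argument, crossprod(a) is t(a) %*% a
print(all.equal(crossprod(a), t(a) %*% a))        # TRUE
```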

@tcgriffith

I believe there's still a bug. According to the Wikipedia article, the rotation matrix um rotates pm onto qm, so the last line should translate the result to the centroid of qm:

sweep(t(tcrossprod(um, pm)), 2, -attr(qm, "scaled:center"))
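To see why this line does the right thing: the rows of `pm` are points, so rotating every point by `um` means computing `um %*% p` for each row `p`, and `t(tcrossprod(um, pm))` (that is, `pm %*% t(um)`) does this in one step. `sweep()` subtracts by default, so sweeping with the negated centroid adds it back. In the sketch below, `um` is an arbitrary orthogonal matrix and `center_q` a made-up centroid, both hypothetical stand-ins for the values computed inside `kabsch()`:

```r
set.seed(3)
um <- qr.Q(qr(matrix(rnorm(9), nrow = 3)))  # hypothetical: any orthogonal 3 x 3 matrix
pm <- matrix(rnorm(12), ncol = 3)            # 4 points in 3D, one per row
center_q <- c(10, 20, 30)                    # hypothetical centroid of qm

# Apply um to each point (row) separately ...
row_by_row <- t(apply(pm, 1, function(p) um %*% p))
# ... which matches the one-line expression used in the gist
one_shot <- t(tcrossprod(um, pm))
print(all.equal(row_by_row, one_shot))  # TRUE

# sweep() subtracts by default, so sweeping -center_q adds center_q to each row
translated <- sweep(one_shot, 2, -center_q)
print(all.equal(translated, one_shot + rep(center_q, each = nrow(one_shot))))  # TRUE
```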

Here is an example (sorry, it's a bit long):

library(ggplot2)

kabsch <- function(pm, qm) {
  pm_dims <- dim(pm)
  if (!all(dim(qm) == pm_dims)) {
    stop(call. = TRUE, "Point sets must have the same dimensions")
  }
  # The rotation matrix will have (ncol - 1) leading ones in the diagonal
  diag_ones <- rep(1, pm_dims[2] - 1)

  # center the points
  pm <- scale(pm, center = TRUE, scale = FALSE)
  qm <- scale(qm, center = TRUE, scale = FALSE)

  am <- crossprod(pm, qm)

  svd_res <- svd(am)
  # use the sign of the determinant to ensure a right-hand coordinate system
  d <- determinant(tcrossprod(svd_res$v, svd_res$u))$sign
  dm <- diag(c(diag_ones, d))

  # rotation matrix
  um <- svd_res$v %*% tcrossprod(dm, svd_res$u)

  # Rotate and then translate to the original centroid location of qm
  sweep(t(tcrossprod(um, pm)), 2, -attr(qm, "scaled:center"))
}


pm2 <- data.frame(
  x = c(2, 1, 1),
  y = c(2, 2, 1)
)

qm2 <- data.frame(
  x = c(0, 0, -1.5),
  y = c(0, 1.5, 1.5)
)


ggplot(pm2,aes(x,y))+
  geom_point(color = "red")+
  geom_path(color="red")+
  geom_point(data=qm2, aes(x=x,y=y),color = "blue")+
  geom_path(data=qm2, aes(x=x,y=y),color = "blue")+
  xlim(-3,3)+
  ylim(-3,3)

pm2.t <- kabsch(pm2,qm2)
pm2.t <- as.data.frame(pm2.t)
names(pm2.t) <- c("x","y")

ggplot(qm2,aes(x,y))+
  geom_point(color = "red")+
  geom_path(color="red")+
  geom_point(data=pm2.t, aes(x=x,y=y),color = "blue")+
  geom_path(data=pm2.t, aes(x=x,y=y),color = "blue")+
  xlim(-3,3)+
  ylim(-3,3)

Created on 2018-08-21 by the reprex package (v0.2.0.9000).

@jlmelville
Author

Thank you @tcgriffith. I don't get alerted to comments, so apologies for not acknowledging this sooner.

This must be a personal record for the most errors in the least amount of code.
