
@jlmelville
Last active January 6, 2022 20:32
The Kabsch algorithm in R for aligning one point set over another
#' Kabsch Algorithm
#'
#' Aligns two sets of points via rotations and translations.
#'
#' Given two sets of points, with one specified as the reference set,
#' the other set will be rotated and translated so that the root-mean-square
#' deviation (RMSD) between the two is minimized. Each matrix should have one
#' row for each of the n observations, and its number of columns, d, is the
#' dimensionality of the points. The point sets must be of equal size and in
#' the same order, i.e. point one of the second matrix is mapped to point one
#' of the reference matrix, point two of the second matrix is mapped to point
#' two of the reference matrix, and so on.
#'
#' @param pm n x d matrix of points to align to \code{qm}.
#' @param qm n x d matrix of reference points.
#' @return Matrix \code{pm} rotated and translated so that the ith point
#' is aligned to the ith point of \code{qm} in the least-squares sense.
#' @references
#' \url{https://en.wikipedia.org/wiki/Kabsch_algorithm}
kabsch <- function(pm, qm) {
  pm_dims <- dim(pm)
  if (!all(dim(qm) == pm_dims)) {
    stop("Point sets must have the same dimensions", call. = TRUE)
  }
  # The sign-correction matrix dm has (ncol - 1) leading ones on its diagonal
  diag_ones <- rep(1, pm_dims[2] - 1)
  # center the points
  pm <- scale(pm, center = TRUE, scale = FALSE)
  qm <- scale(qm, center = TRUE, scale = FALSE)
  # d x d matrix t(pm) %*% qm of the centered points (unnormalized cross-covariance)
  am <- crossprod(pm, qm)
  svd_res <- svd(am)
  # use the sign of det(V U^T) to avoid a reflection, so that um is a proper
  # rotation (right-handed coordinate system)
  d <- determinant(tcrossprod(svd_res$v, svd_res$u))$sign
  # nrow is given explicitly so that one-dimensional input also gets a 1 x 1 matrix
  dm <- diag(c(diag_ones, d), nrow = pm_dims[2])
  # rotation matrix
  um <- svd_res$v %*% tcrossprod(dm, svd_res$u)
  # Rotate and then translate to the original centroid location of qm
  sweep(t(tcrossprod(um, pm)), 2, -attr(qm, "scaled:center"))
}
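A minimal usage sketch, assuming the reference points are random and the second set is an exact rotated and translated copy of them. It repeats a condensed version of kabsch() (dimension check reduced to stopifnot()) so it runs on its own; rmsd() is a hypothetical helper, not part of the gist. Aligning the copy should recover the reference up to floating-point error, while a mirror image cannot be matched, because the determinant sign restricts the result to a proper rotation.

```r
# Condensed copy of kabsch() from above so this snippet runs on its own
kabsch <- function(pm, qm) {
  stopifnot(identical(dim(pm), dim(qm)))
  pm <- scale(pm, center = TRUE, scale = FALSE)
  qm <- scale(qm, center = TRUE, scale = FALSE)
  svd_res <- svd(crossprod(pm, qm))
  d <- determinant(tcrossprod(svd_res$v, svd_res$u))$sign
  dm <- diag(c(rep(1, ncol(pm) - 1), d), nrow = ncol(pm))
  um <- svd_res$v %*% tcrossprod(dm, svd_res$u)
  sweep(t(tcrossprod(um, pm)), 2, -attr(qm, "scaled:center"))
}

# Hypothetical helper: root-mean-square deviation between corresponding rows
rmsd <- function(a, b) sqrt(mean(rowSums((a - b)^2)))

set.seed(42)
qm <- matrix(rnorm(30), ncol = 3)  # 10 reference points in 3D

# Rotate qm by a known angle about the z axis, then translate it
theta <- pi / 5
rot_z <- matrix(c(cos(theta), sin(theta), 0,
                  -sin(theta), cos(theta), 0,
                  0, 0, 1), nrow = 3)
pm <- sweep(qm %*% t(rot_z), 2, c(5, -3, 2), FUN = "+")

aligned <- kabsch(pm, qm)
rmsd(pm, qm)       # clearly non-zero: pm is rotated and shifted
rmsd(aligned, qm)  # effectively zero

# A mirror image (x flipped) cannot be matched by a proper rotation
mirrored <- qm %*% diag(c(-1, 1, 1))
rmsd(kabsch(mirrored, qm), qm)  # stays above zero
```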
@jlmelville
Author

Glad this was helpful. I have incorporated your corrections, thank you!

I have also fixed two embarrassing errors:

  1. The right and left singular vectors should have been swapped when generating the rotation matrix.
  2. The centering operation over pm was completely wrong.
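To illustrate the first error, here is a small sketch assuming three arbitrary 2D points and a known rotation by theta. With H = t(pm) %*% qm = U S V^T, the matrix V D U^T rotates the rows of pm onto qm, while the swapped U D V^T is its transpose, i.e. the inverse rotation (by -theta here). The variable names um_correct and um_swapped are made up for the illustration.

```r
# Three arbitrary 2D points, centered so that only a rotation relates pm and qm
pm <- matrix(c(1, 2, 0, 3, 0, 1), ncol = 2)
pm <- sweep(pm, 2, colMeans(pm))

theta <- pi / 3
rot <- matrix(c(cos(theta), sin(theta), -sin(theta), cos(theta)), nrow = 2)
qm <- pm %*% t(rot)  # row i of qm is rot applied to row i of pm

svd_res <- svd(crossprod(pm, qm))
dm <- diag(c(1, determinant(tcrossprod(svd_res$v, svd_res$u))$sign))

um_correct <- svd_res$v %*% tcrossprod(dm, svd_res$u)  # V D U^T
um_swapped <- svd_res$u %*% tcrossprod(dm, svd_res$v)  # U D V^T = t(um_correct)

all.equal(pm %*% t(um_correct), qm)          # TRUE
isTRUE(all.equal(pm %*% t(um_swapped), qm))  # FALSE: rotates by -theta instead
```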

Memo to self: in future, check that kabsch(pm, pm) == pm and similarly for qm before posting a gist. D'oh.
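The memo's check can be written as a couple of lines of code. This sketch uses all.equal() rather than ==, since the result passes through a floating-point SVD and exact equality is fragile; aligns_to_itself() is a hypothetical helper and the coordinates are arbitrary. A condensed copy of kabsch() is included so the snippet runs on its own. Aligning a set to itself should give back the set unchanged, because t(pm) %*% pm is symmetric, its left and right singular vectors coincide and the rotation reduces to the identity.

```r
# Condensed copy of kabsch() from above so this snippet runs on its own
kabsch <- function(pm, qm) {
  stopifnot(identical(dim(pm), dim(qm)))
  pm <- scale(pm, center = TRUE, scale = FALSE)
  qm <- scale(qm, center = TRUE, scale = FALSE)
  svd_res <- svd(crossprod(pm, qm))
  d <- determinant(tcrossprod(svd_res$v, svd_res$u))$sign
  dm <- diag(c(rep(1, ncol(pm) - 1), d), nrow = ncol(pm))
  um <- svd_res$v %*% tcrossprod(dm, svd_res$u)
  sweep(t(tcrossprod(um, pm)), 2, -attr(qm, "scaled:center"))
}

# Hypothetical helper: the memo's self-alignment check, with a tolerance
aligns_to_itself <- function(xm) {
  isTRUE(all.equal(kabsch(xm, xm), xm, check.attributes = FALSE))
}

pm <- matrix(c(2, 1, 1, 2, 2, 1), ncol = 2)
qm <- matrix(c(0, 0, -1.5, 0, 1.5, 1.5), ncol = 2)
aligns_to_itself(pm)  # TRUE
aligns_to_itself(qm)  # TRUE
```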

@jlmelville
Author

jlmelville commented Dec 7, 2017

Minor change to use crossprod and tcrossprod rather than %*% and t() directly.
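For reference, R documents crossprod(x, y) as formally equivalent to, but usually slightly faster than, t(x) %*% y, and tcrossprod(x, y) likewise for x %*% t(y). The snippet checks both identities on small arbitrary matrices, plus the transposed form used in the last line of kabsch(); the matrix um here is just a 90-degree rotation chosen for illustration.

```r
a <- matrix(c(1, 2, 3, 4, 5, 6), nrow = 3)     # 3 x 2
b <- matrix(c(7, 8, 9, 10, 11, 12), nrow = 3)  # 3 x 2

all.equal(crossprod(a, b), t(a) %*% b)   # TRUE, 2 x 2
all.equal(tcrossprod(a, b), a %*% t(b))  # TRUE, 3 x 3

# The last line of kabsch() rotates the rows of a point matrix:
# t(tcrossprod(um, pm)) is the same as pm %*% t(um)
um <- matrix(c(0, 1, -1, 0), nrow = 2)
all.equal(t(tcrossprod(um, a)), a %*% t(um))  # TRUE
```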

@tcgriffith

I believe there's still a bug. According to the wiki, the rotation matrix um rotates pm onto qm, so the last line should be

sweep(t(tcrossprod(um, pm)), 2, -attr(qm, "scaled:center"))
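The sign in that line is easy to misread: sweep() subtracts by default (FUN = "-"), so passing the negated centroid adds qm's centroid back, and scale(..., scale = FALSE) stores that centroid in the "scaled:center" attribute. A small sketch with arbitrary coordinates:

```r
qm <- matrix(c(0, 0, -1.5, 0, 1.5, 1.5), ncol = 2)
qc <- scale(qm, center = TRUE, scale = FALSE)
centroid <- attr(qc, "scaled:center")  # column means: c(-0.5, 1)

# Subtracting -centroid is the same as adding centroid
restored <- sweep(qc, 2, -centroid)
all.equal(restored, sweep(qc, 2, centroid, FUN = "+"))  # TRUE
all.equal(restored, qm, check.attributes = FALSE)       # TRUE
```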

Here is an example (sorry it's a bit long):

library(ggplot2)

kabsch <- function(pm, qm) {
  pm_dims <- dim(pm)
  if (!all(dim(qm) == pm_dims)) {
    stop(call. = TRUE, "Point sets must have the same dimensions")
  }
  # The sign-correction matrix dm has (ncol - 1) leading ones on its diagonal
  diag_ones <- rep(1, pm_dims[2] - 1)

  # center the points
  pm <- scale(pm, center = TRUE, scale = FALSE)
  qm <- scale(qm, center = TRUE, scale = FALSE)

  am <- crossprod(pm, qm)

  svd_res <- svd(am)
  # use the sign of the determinant to ensure a right-hand coordinate system
  d <- determinant(tcrossprod(svd_res$v, svd_res$u))$sign
  dm <- diag(c(diag_ones, d))

  # rotation matrix
  um <- svd_res$v %*% tcrossprod(dm, svd_res$u)

  # Rotate and then translate to the original centroid location of qm
  sweep(t(tcrossprod(um, pm)), 2, -attr(qm, "scaled:center"))
}


pm2 <- data.frame(
  x = c(2, 1, 1),
  y = c(2, 2, 1)
)

qm2 <- data.frame(
  x = c(0, 0, -1.5),
  y = c(0, 1.5, 1.5)
)


ggplot(pm2,aes(x,y))+
  geom_point(color = "red")+
  geom_path(color="red")+
  geom_point(data=qm2, aes(x=x,y=y),color = "blue")+
  geom_path(data=qm2, aes(x=x,y=y),color = "blue")+
  xlim(-3,3)+
  ylim(-3,3)

pm2.t <- kabsch(pm2,qm2)
pm2.t <- as.data.frame(pm2.t)
names(pm2.t) <- c("x","y")

ggplot(qm2,aes(x,y))+
  geom_point(color = "red")+
  geom_path(color="red")+
  geom_point(data=pm2.t, aes(x=x,y=y),color = "blue")+
  geom_path(data=pm2.t, aes(x=x,y=y),color = "blue")+
  xlim(-3,3)+
  ylim(-3,3)

Created on 2018-08-21 by the reprex package (v0.2.0.9000).
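Kabsch alignment only rotates and translates; it never rescales. A quick way to see the best the second plot can show is to compare the pairwise distances within each set. In this example every distance in qm2 is 1.5 times the matching distance in pm2, so even a correctly rotated pm2 cannot sit exactly on qm2. The sketch uses only base R dist(); d_pm and d_qm are illustrative names.

```r
# Same coordinates as pm2 and qm2 in the example
pm2 <- data.frame(x = c(2, 1, 1), y = c(2, 2, 1))
qm2 <- data.frame(x = c(0, 0, -1.5), y = c(0, 1.5, 1.5))

# Pairwise distances within each set, in the order (1,2), (1,3), (2,3)
d_pm <- as.vector(dist(pm2))
d_qm <- as.vector(dist(qm2))
d_qm / d_pm  # all 1.5: qm2 is a scaled copy, which no rotation can undo
```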

@jlmelville
Author

Thank you @tcgriffith. I don't get alerted to comments, so apologies for not acknowledging this sooner.

This must be a personal record for the most errors in the least amount of code.
