
# jlmelville/kabsch.R

Last active January 6, 2022 20:32
The Kabsch algorithm in R for aligning one point set over another
```r
#' Kabsch Algorithm
#'
#' Aligns two sets of points via rotations and translations.
#'
#' Given two sets of points, with one specified as the reference set,
#' the other set will be rotated and translated so that the RMSD between the
#' two is minimized.
#' Each matrix should have one row for each of n observations, and the number
#' of columns, d, specifies the dimensionality of the points. The point sets
#' must be of equal size and with the same ordering, i.e. point one of the
#' second matrix is mapped to point one of the reference matrix, point two of
#' the second matrix is mapped to point two of the reference matrix, and so on.
#'
#' @param pm n x d matrix of points to align to \code{qm}.
#' @param qm n x d matrix of reference points.
#' @return Matrix \code{pm} rotated and translated so that the ith point
#'  is aligned to the ith point of \code{qm} in the least-squares sense.
#' @references
#' \url{https://en.wikipedia.org/wiki/Kabsch_algorithm}
kabsch <- function(pm, qm) {
  pm_dims <- dim(pm)
  if (!all(dim(qm) == pm_dims)) {
    stop(call. = TRUE, "Point sets must have the same dimensions")
  }
  # The rotation matrix will have (ncol - 1) leading ones in the diagonal
  diag_ones <- rep(1, pm_dims[2] - 1)

  # center the points
  pm <- scale(pm, center = TRUE, scale = FALSE)
  qm <- scale(qm, center = TRUE, scale = FALSE)

  am <- crossprod(pm, qm)

  svd_res <- svd(am)
  # use the sign of the determinant to ensure a right-hand coordinate system
  d <- determinant(tcrossprod(svd_res$v, svd_res$u))$sign
  dm <- diag(c(diag_ones, d))

  # rotation matrix
  um <- svd_res$v %*% tcrossprod(dm, svd_res$u)

  # Rotate and then translate to the original centroid location of qm
  sweep(t(tcrossprod(um, pm)), 2, -attr(qm, "scaled:center"))
}
```
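A minimal usage sketch, assuming random made-up test data. It repeats the gist's `kabsch()` so the block runs on its own; the helper `rmsd()` is hypothetical and not part of the gist. It checks three properties any rigid least-squares alignment should have: distances between points of `pm` are preserved, the aligned centroid lands on `qm`'s centroid, and the RMSD to `qm` is no larger than before alignment (the identity is itself a rigid motion, so the optimum cannot be worse).

```r
# Copy of the gist's kabsch() so this block is self-contained
kabsch <- function(pm, qm) {
  pm_dims <- dim(pm)
  if (!all(dim(qm) == pm_dims)) {
    stop(call. = TRUE, "Point sets must have the same dimensions")
  }
  diag_ones <- rep(1, pm_dims[2] - 1)
  # center both point sets
  pm <- scale(pm, center = TRUE, scale = FALSE)
  qm <- scale(qm, center = TRUE, scale = FALSE)
  svd_res <- svd(crossprod(pm, qm))
  # sign correction so the result is a rotation, not a reflection
  d <- determinant(tcrossprod(svd_res$v, svd_res$u))$sign
  dm <- diag(c(diag_ones, d))
  um <- svd_res$v %*% tcrossprod(dm, svd_res$u)
  # rotate, then translate to the centroid of qm
  sweep(t(tcrossprod(um, pm)), 2, -attr(qm, "scaled:center"))
}

# Hypothetical helper: root-mean-square deviation between matched rows
rmsd <- function(a, b) sqrt(mean(rowSums((a - b)^2)))

set.seed(2022)
pm <- matrix(rnorm(30), ncol = 3)       # 10 made-up points in 3D
qm <- matrix(rnorm(30), ncol = 3) + 5   # unrelated reference set, offset
aligned <- kabsch(pm, qm)

rmsd(pm, qm)       # before alignment
rmsd(aligned, qm)  # after alignment: no larger than before
```

Because the data are random, the two RMSD values themselves carry no meaning; only their ordering does.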

### radt0005 commented Apr 9, 2017

Thanks! This was very helpful for finding rotation matrices for pairs of points in 3D. I had a little trouble with line 33 causing

```
Error in svd_res$u %*% dm : non-conformable arguments
```

and I noticed that `dm` was 2x2 when, in my case, it needed to be 3x3. Here's an easy fix to put in place of line 30, with an added error trap to make sure `pm` and `qm` are the same size:

```r
if (all(dim(pm) == dim(qm))) {
  l <- dim(qm)[2] - 1
} else {
  stop(call. = TRUE, "Point sets must have same dimensions.")
}

dm <- diag(c(rep(1, l), d))
```
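The sizing issue can be reproduced on its own. In this sketch `make_dm()` is a hypothetical helper that mirrors the construction used in the gist (`ncol - 1` ones followed by the determinant sign on the diagonal), so `D` always matches the dimensionality of the points; multiplying a 3 x 3 matrix by a 2 x 2 one triggers the same "non-conformable arguments" error reported above.

```r
# Hypothetical helper mirroring the gist: (n_dims - 1) ones, then the
# determinant sign d, placed on the diagonal of an n_dims x n_dims matrix
make_dm <- function(n_dims, d) {
  diag(c(rep(1, n_dims - 1), d))
}

dm3 <- make_dm(3, -1)   # 3 x 3, as needed for 3D points
print(dm3)

# A 3 x 3 singular-vector matrix times a 2 x 2 D fails to conform
msg <- tryCatch(
  diag(3) %*% diag(2),
  error = function(e) conditionMessage(e)
)
print(msg)
```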

### jlmelville commented Aug 27, 2017

I have also fixed two embarrassing errors:

1. The right and left singular vectors should have been swapped when generating the rotation matrix.
2. The centering operation over `pm` was completely wrong.
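The first fix can be checked in isolation. This sketch assumes a known rotation `r0` about the z axis applied to random, centred 3D points (both made up for illustration). With `A = t(P) %*% Q = U S t(V)`, the product `V D t(U)` recovers `r0`, while the swapped product `U D t(V)` gives its transpose, i.e. the inverse rotation.

```r
set.seed(1)
# Made-up data: 10 random 3D points, centred
pm <- scale(matrix(rnorm(30), ncol = 3), center = TRUE, scale = FALSE)

# Known proper rotation by 30 degrees about the z axis
theta <- pi / 6
r0 <- matrix(c(cos(theta), sin(theta), 0,
               -sin(theta), cos(theta), 0,
               0, 0, 1), nrow = 3)

# Rotate each row: q_i = r0 %*% p_i, i.e. Q = P %*% t(r0)
qm <- pm %*% t(r0)

svd_res <- svd(crossprod(pm, qm))   # A = t(P) %*% Q = U S t(V)
d <- determinant(tcrossprod(svd_res$v, svd_res$u))$sign
dm <- diag(c(1, 1, d))

r_correct <- svd_res$v %*% dm %*% t(svd_res$u)   # V D t(U)
r_swapped <- svd_res$u %*% dm %*% t(svd_res$v)   # U D t(V): wrong
```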

Memo to self: in future, check that `kabsch(pm, pm) == pm` and similarly for `qm` before posting gist. D'oh.
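The memo above can be turned into a runnable self-test. This sketch copies the gist's `kabsch()` and uses made-up random points; since exact `==` is too strict for floating-point results, it compares with `all.equal()` instead. It also checks a known 60-degree rotation plus translation (chosen only for illustration), onto which `pm` should be mapped exactly.

```r
# Copy of the gist's kabsch() so this block is self-contained
kabsch <- function(pm, qm) {
  pm_dims <- dim(pm)
  if (!all(dim(qm) == pm_dims)) {
    stop(call. = TRUE, "Point sets must have the same dimensions")
  }
  diag_ones <- rep(1, pm_dims[2] - 1)
  pm <- scale(pm, center = TRUE, scale = FALSE)
  qm <- scale(qm, center = TRUE, scale = FALSE)
  svd_res <- svd(crossprod(pm, qm))
  d <- determinant(tcrossprod(svd_res$v, svd_res$u))$sign
  dm <- diag(c(diag_ones, d))
  um <- svd_res$v %*% tcrossprod(dm, svd_res$u)
  sweep(t(tcrossprod(um, pm)), 2, -attr(qm, "scaled:center"))
}

set.seed(123)
pm <- matrix(rnorm(20), ncol = 2)   # 10 made-up 2D points
qm <- matrix(rnorm(20), ncol = 2)

# Aligning a set onto itself should return it unchanged (up to rounding)
all.equal(kabsch(pm, pm), pm)
all.equal(kabsch(qm, qm), qm)

# Rotate pm by 60 degrees and shift it; kabsch() should map pm onto it exactly
theta <- pi / 3
rot <- matrix(c(cos(theta), sin(theta), -sin(theta), cos(theta)), nrow = 2)
qm_exact <- sweep(pm %*% t(rot), 2, c(3, -1), FUN = "+")
all.equal(kabsch(pm, qm_exact), qm_exact)
```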

Minor change to use `crossprod` and `tcrossprod` rather than `%*%` and `t()` directly.
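For reference, the equivalences behind that change, sketched with made-up matrices (`r` merely stands in for a rotation matrix). R's documentation describes `crossprod(x, y)` as formally equivalent to, but usually slightly faster than, `t(x) %*% y`; the last line shows the pattern used in the gist's final line.

```r
set.seed(7)
x <- matrix(rnorm(6), nrow = 3)   # 3 x 2
y <- matrix(rnorm(6), nrow = 3)   # 3 x 2
r <- matrix(rnorm(4), nrow = 2)   # 2 x 2 stand-in for a rotation matrix

cp  <- crossprod(x, y)            # same as t(x) %*% y  (2 x 2)
tcp <- tcrossprod(x, y)           # same as x %*% t(y)  (3 x 3)
rotated <- t(tcrossprod(r, x))    # same as x %*% t(r): rotates each row of x
```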

### tcgriffith commented Aug 21, 2018

I believe there's still a bug. According to the Wikipedia article, the rotation matrix `um` rotates `Pm` onto `Qm`, so the rotated points belong at the centroid of `qm` and the last line should be

`sweep(t(tcrossprod(um, pm)), 2, -attr(qm, "scaled:center"))`
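How that line moves the points, sketched on three made-up 2D points: `scale(..., scale = FALSE)` stores the column means it subtracted in the `"scaled:center"` attribute, and because `sweep()` subtracts by default, sweeping out the negated centre adds it back.

```r
pts <- matrix(c(0, 1, 2, 0, 0, 1), ncol = 2)   # three 2D points, one per row

# scale() records the column means it removed
centred <- scale(pts, center = TRUE, scale = FALSE)
centre <- attr(centred, "scaled:center")

# sweep() uses FUN = "-" by default, so subtracting -centre adds centre
restored <- sweep(centred, 2, -centre)
```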

Here is an example (sorry, it's a bit long):

```r
library(ggplot2)

kabsch <- function(pm, qm) {
  pm_dims <- dim(pm)
  if (!all(dim(qm) == pm_dims)) {
    stop(call. = TRUE, "Point sets must have the same dimensions")
  }
  # The rotation matrix will have (ncol - 1) leading ones in the diagonal
  diag_ones <- rep(1, pm_dims[2] - 1)

  # center the points
  pm <- scale(pm, center = TRUE, scale = FALSE)
  qm <- scale(qm, center = TRUE, scale = FALSE)

  am <- crossprod(pm, qm)

  svd_res <- svd(am)
  # use the sign of the determinant to ensure a right-hand coordinate system
  d <- determinant(tcrossprod(svd_res$v, svd_res$u))$sign
  dm <- diag(c(diag_ones, d))

  # rotation matrix
  um <- svd_res$v %*% tcrossprod(dm, svd_res$u)

  # Rotate and then translate to the original centroid location of qm
  sweep(t(tcrossprod(um, pm)), 2, -attr(qm, "scaled:center"))
}

pm2 <- data.frame(
  x = c(2, 1, 1),
  y = c(2, 2, 1)
)

qm2 <- data.frame(
  x = c(0, 0, -1.5),
  y = c(0, 1.5, 1.5)
)

# Before alignment: pm2 in red, qm2 in blue
ggplot(pm2, aes(x, y)) +
  geom_point(color = "red") +
  geom_path(color = "red") +
  geom_point(data = qm2, aes(x = x, y = y), color = "blue") +
  geom_path(data = qm2, aes(x = x, y = y), color = "blue") +
  xlim(-3, 3) +
  ylim(-3, 3)
```

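```r
# Align pm2 onto qm2 and convert the result back to a data frame
pm2.t <- kabsch(pm2, qm2)
pm2.t <- as.data.frame(pm2.t)
names(pm2.t) <- c("x", "y")

# After alignment: qm2 in red, aligned pm2 in blue
ggplot(qm2, aes(x, y)) +
  geom_point(color = "red") +
  geom_path(color = "red") +
  geom_point(data = pm2.t, aes(x = x, y = y), color = "blue") +
  geom_path(data = pm2.t, aes(x = x, y = y), color = "blue") +
  xlim(-3, 3) +
  ylim(-3, 3)
```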

Created on 2018-08-21 by the reprex package (v0.2.0.9000).

### jlmelville commented Nov 27, 2018

Thank you @tcgriffith. I don't get alerted to comments, so apologies for not acknowledging this sooner.

This must be a personal record for the most errors in the least amount of code.