Created
October 9, 2019 12:31
-
-
Save mkocabas/54ea2ff3b03260e3fedf8ad22536f427 to your computer and use it in GitHub Desktop.
PyTorch batch Procrustes implementation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import torch | |
def compute_similarity_transform(S1, S2):
    """Align S1 onto S2 with a least-squares similarity transform.

    Solves the orthogonal Procrustes problem with scaling: finds a scale s,
    a rotation R (det(R) = +1) and a translation t so that s * R @ S1 + t is
    as close as possible to S2, and returns that transformed copy of S1.

    Args:
        S1: numpy array of points to align, laid out as (D, N) with D in
            {2, 3}, or as (N, D). The layout is decided from S1.shape[0].
        S2: numpy array of target points in the same layout as S1.

    Returns:
        The aligned points s * R @ S1 + t, in the same layout as the input.
        If all points of S1 coincide, the scale divides by zero.
    """
    # Any leading dimension other than 2 or 3 is taken to mean points-as-rows.
    points_as_rows = S1.shape[0] not in (2, 3)
    if points_as_rows:
        S1, S2 = S1.T, S2.T
    assert S2.shape[1] == S1.shape[1]

    # Centre both point sets on their centroids.
    centroid_src = S1.mean(axis=1, keepdims=True)
    centroid_dst = S2.mean(axis=1, keepdims=True)
    src = S1 - centroid_src
    dst = S2 - centroid_dst

    # Total squared norm of the centred source: the denominator of the scale.
    src_energy = (src ** 2).sum()

    # Cross-covariance of the centred sets. R = V Z U^T maximises trace(R @ cov).
    cov = src @ dst.T
    U, _, Vt = np.linalg.svd(cov)
    V = Vt.T

    # Flip the weakest singular direction when needed so R is a proper
    # rotation (det = +1) rather than a reflection.
    fix = np.eye(U.shape[0])
    fix[-1, -1] *= np.sign(np.linalg.det(U @ V.T))
    rotation = V @ (fix @ U.T)

    scale = np.trace(rotation @ cov) / src_energy
    shift = centroid_dst - scale * (rotation @ centroid_src)

    aligned = scale * (rotation @ S1) + shift
    return aligned.T if points_as_rows else aligned
def compute_similarity_transform_torch(S1, S2):
    """Align S1 onto S2 with a least-squares similarity transform (PyTorch).

    Solves the orthogonal Procrustes problem with scaling: finds a scale s,
    a rotation R (det(R) = +1) and a translation t so that s * R @ S1 + t is
    as close as possible to S2, and returns that transformed copy of S1.

    Args:
        S1: tensor of points to align, laid out as (D, N) with D in {2, 3},
            or as (N, D). The layout is decided from S1.shape[0].
        S2: tensor of target points in the same layout as S1.

    Returns:
        The aligned points s * R @ S1 + t, in the same layout, dtype and
        device as the input. If all points of S1 coincide, the scale divides
        by zero.
    """
    # Any leading dimension other than 2 or 3 is taken to mean points-as-rows.
    transposed = S1.shape[0] not in (2, 3)
    if transposed:
        S1 = S1.T
        S2 = S2.T
    assert S2.shape[1] == S1.shape[1]

    # 1. Remove the centroids.
    mu1 = S1.mean(dim=1, keepdim=True)
    mu2 = S2.mean(dim=1, keepdim=True)
    X1 = S1 - mu1
    X2 = S2 - mu2

    # 2. Total squared norm of the centred source, used for the scale.
    var1 = torch.sum(X1 ** 2)

    # 3. Cross-covariance of the centred point sets.
    K = X1.mm(X2.T)

    # 4. R = V Z U^T maximises trace(R K). torch.svd returns V itself
    #    (not V^T), i.e. K = U diag(s) V^T.
    U, _, V = torch.svd(K)
    # Z flips the last singular direction when needed so det(R) = +1.
    # It must share U's dtype: a default-dtype (float32) eye makes the mm
    # below fail for float64 inputs.
    Z = torch.eye(U.shape[0], dtype=U.dtype, device=U.device)
    Z[-1, -1] *= torch.sign(torch.det(U @ V.T))
    R = V.mm(Z.mm(U.T))

    # 5. Optimal scale.
    scale = torch.trace(R.mm(K)) / var1

    # 6. Translation mapping the scaled, rotated centroid of S1 onto S2's.
    t = mu2 - scale * R.mm(mu1)

    # 7. Apply the transform.
    S1_hat = scale * R.mm(S1) + t
    if transposed:
        S1_hat = S1_hat.T
    return S1_hat
def batch_compute_similarity_transform_torch(S1, S2):
    """Align each S1[b] onto S2[b] with a least-squares similarity transform.

    Batched orthogonal Procrustes with scaling. For every batch element, finds
    a scale s, a rotation R (det(R) = +1) and a translation t so that
    s * R @ S1[b] + t is as close as possible to S2[b], and returns the
    transformed S1.

    Args:
        S1: tensor of shape (B, D, N) with D in {2, 3}, or (B, N, D). The
            layout is decided from S1.shape[1]. Dim 0 is always the batch.
            NOTE(review): a (B, N, D) input with N in {2, 3} points is
            indistinguishable from (B, D, N) and is treated as the latter.
        S2: target tensor with exactly the same shape as S1.

    Returns:
        The aligned points, same shape, dtype and device as S1. A batch
        element whose S1 points all coincide gets a division by zero in its
        scale.
    """
    # The coordinate axis is dim 1 or dim 2; dim 0 is the batch size and must
    # not take part in the layout decision.
    transposed = S1.shape[1] not in (2, 3)
    if transposed:
        S1 = S1.permute(0, 2, 1)
        S2 = S2.permute(0, 2, 1)
    assert S2.shape == S1.shape

    # 1. Remove the per-sample centroids (mean over the N points).
    mu1 = S1.mean(dim=-1, keepdim=True)
    mu2 = S2.mean(dim=-1, keepdim=True)
    X1 = S1 - mu1
    X2 = S2 - mu2

    # 2. Per-sample total squared norm of the centred source, shape (B,).
    var1 = (X1 ** 2).sum(dim=(1, 2))

    # 3. Per-sample cross-covariance, shape (B, D, D).
    K = X1.bmm(X2.transpose(1, 2))

    # 4. R = V Z U^T maximises trace(R K). torch.svd returns V itself
    #    (not V^T), i.e. K = U diag(s) V^T.
    U, _, V = torch.svd(K)
    # Z flips the last singular direction where needed so det(R) = +1.
    # It must share U's dtype: a default-dtype (float32) eye makes the bmm
    # below fail for float64 inputs.
    Z = torch.eye(U.shape[-1], dtype=U.dtype, device=U.device).repeat(U.shape[0], 1, 1)
    Z[:, -1, -1] *= torch.sign(torch.det(U.bmm(V.transpose(1, 2))))
    R = V.bmm(Z.bmm(U.transpose(1, 2)))

    # 5. Per-sample optimal scale: batched trace(R K) / var1, with shape
    #    (B, 1, 1) so it broadcasts over (B, D, N).
    scale = torch.diagonal(R.bmm(K), dim1=-2, dim2=-1).sum(dim=-1) / var1
    scale = scale.unsqueeze(-1).unsqueeze(-1)

    # 6. Translation mapping the scaled, rotated centroid of S1 onto S2's.
    t = mu2 - scale * R.bmm(mu1)

    # 7. Apply the transform.
    S1_hat = scale * R.bmm(S1) + t
    if transposed:
        S1_hat = S1_hat.permute(0, 2, 1)
    return S1_hat
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
@dusangrujicic does this solve it? https://github.com/brando90/ultimate-anatome/blob/7d23ca83bac7201a91c5515145269d2288f7bbb6/anatome/similarity.py#L518