Skip to content

Instantly share code, notes, and snippets.

@podgorskiy
Last active February 26, 2020 06:34
Show Gist options
  • Save podgorskiy/708c92fec4ef4f3a583e8e9f541c0134 to your computer and use it in GitHub Desktop.
Save podgorskiy/708c92fec4ef4f3a583e8e9f541c0134 to your computer and use it in GitHub Desktop.
Torch module that represents a learnable, parametrized rotation matrix
import torch
from torch import nn
import numpy as np
class RotationMatrix(torch.nn.Module):
    """Learnable 3x3 rotation matrix parametrized by an axis-angle vector.

    The single parameter ``betta`` is a 3-vector. Its direction is the
    rotation axis and its Euclidean norm is the rotation angle (radians).
    ``forward()`` takes no inputs and returns the matrix given by Rodrigues'
    formula::

        R = I + sin(theta) * K + (1 - cos(theta)) * (K @ K)

    Here ``theta = |betta|`` and ``K`` is the skew-symmetric cross-product
    matrix of the unit axis ``betta / theta``.
    """

    # Below this squared angle, sin(t)/t and (1 - cos(t))/t**2 are evaluated
    # from their Taylor expansions. This avoids dividing by ~0, which would
    # give NaN for betta == 0.
    small_angle_sq = 1e-6

    def __init__(self):
        super().__init__()
        # Random initial axis-angle vector, drawn from NumPy's global RNG.
        self.betta = nn.Parameter(
            torch.tensor(np.random.randn(3), dtype=torch.float32),
            requires_grad=True,
        )

    def forward(self):
        """Return the 3x3 rotation matrix encoded by ``self.betta``.

        The result is differentiable w.r.t. ``betta``, through both its norm
        (angle) and its direction (axis). It has the parameter's dtype and
        device.
        """
        b = self.betta
        zero = b.new_zeros(())
        # Skew matrix B with B @ v == cross(betta, v), i.e. B = theta * K.
        # Built with torch.stack instead of torch.tensor: torch.tensor copies
        # the values into a new leaf and would cut the autograd graph.
        B = torch.stack([
            torch.stack([zero, -b[2], b[1]]),
            torch.stack([b[2], zero, -b[0]]),
            torch.stack([-b[1], b[0], zero]),
        ])
        theta_sq = torch.dot(b, b)
        if theta_sq < self.small_angle_sq:
            # Taylor series around theta = 0: sin(t)/t and (1 - cos(t))/t^2.
            a = 1.0 - theta_sq / 6.0
            c = 0.5 - theta_sq / 24.0
        else:
            theta = torch.sqrt(theta_sq)
            a = torch.sin(theta) / theta
            c = (1.0 - torch.cos(theta)) / theta_sq
        eye = torch.eye(3, dtype=b.dtype, device=b.device)
        # sin(t) K = a * B and (1 - cos(t)) K @ K = c * B @ B. The square is
        # a matrix product; an elementwise square would not yield a rotation.
        return eye + a * B + c * (B @ B)
def random_rotation(size):
    """Return a random ``size x size`` proper rotation matrix (det == +1).

    A standard-normal matrix is orthonormalized with QR. Signs are then
    fixed so the result is uniformly distributed over SO(size). Reflections
    (det == -1) are excluded because an axis-angle rotation model such as
    RotationMatrix can never reproduce them.
    """
    a = np.random.normal(0.0, 1.0, (size, size))
    q, r = np.linalg.qr(a)
    # Make the factorization unique (positive diag of r) -> Haar on O(size).
    q = q * np.sign(np.diag(r))
    # Flip one column if needed so the result is a rotation, not a reflection.
    if np.linalg.det(q) < 0:
        q[:, 0] = -q[:, 0]
    return q


# Smoke test: recover a random ground-truth rotation from point pairs.
if __name__ == "__main__":
    from torch.nn.functional import mse_loss
    from torch.optim import SGD

    # Ground-truth rotation.
    R = random_rotation(3)
    # Ten random 3-D column vectors and their rotated images.
    X = np.random.randn(3, 10)
    Xp = np.matmul(R, X)

    ################################
    # Input: X, Xp
    # Output: learned R
    ################################
    X = torch.tensor(X, dtype=torch.float32)
    Xp = torch.tensor(Xp, dtype=torch.float32)

    R_ = RotationMatrix()
    optim = SGD(R_.parameters(), 0.01)
    for i in range(1000):
        optim.zero_grad()
        loss = mse_loss(torch.matmul(R_(), X), Xp)
        loss.backward()
        # Report progress periodically instead of on every step.
        if i % 100 == 0:
            print(i, loss.item())
        optim.step()
    print("final loss:", loss.item())
    print("GT:\n", R)
    with torch.no_grad():
        print("Learned:\n", R_())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment