@bougui505
Created October 1, 2020 07:44
Build a rotation matrix in PyTorch
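For reference, the script composes three elementary rotations with angles in radians. Multiplying them out in the same order as `RZ.mm(RY).mm(RX)` gives the closed form below, which may be handy for checking the output by hand. Here $c_\theta$ and $s_\theta$ abbreviate $\cos\theta$ and $\sin\theta$.

```latex
R = R_z(\gamma)\,R_y(\beta)\,R_x(\alpha), \qquad
R_x(\alpha) = \begin{pmatrix} 1 & 0 & 0 \\ 0 & c_\alpha & -s_\alpha \\ 0 & s_\alpha & c_\alpha \end{pmatrix}, \quad
R_y(\beta) = \begin{pmatrix} c_\beta & 0 & s_\beta \\ 0 & 1 & 0 \\ -s_\beta & 0 & c_\beta \end{pmatrix}, \quad
R_z(\gamma) = \begin{pmatrix} c_\gamma & -s_\gamma & 0 \\ s_\gamma & c_\gamma & 0 \\ 0 & 0 & 1 \end{pmatrix}

R = \begin{pmatrix}
c_\gamma c_\beta & c_\gamma s_\beta s_\alpha - s_\gamma c_\alpha & c_\gamma s_\beta c_\alpha + s_\gamma s_\alpha \\
s_\gamma c_\beta & s_\gamma s_\beta s_\alpha + c_\gamma c_\alpha & s_\gamma s_\beta c_\alpha - c_\gamma s_\alpha \\
-s_\beta & c_\beta s_\alpha & c_\beta c_\alpha
\end{pmatrix}
```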
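#!/usr/bin/env python
# -*- coding: UTF8 -*-
# Author: Guillaume Bouvier -- guillaume.bouvier@pasteur.fr
# https://research.pasteur.fr/en/member/guillaume-bouvier/
# 2020-10-01 09:43:30 (UTC+0200)
import torch


def build_rotation_matrix(alpha_beta_gamma, device):
    """Return the 3x3 rotation matrix R = RZ(gamma) @ RY(beta) @ RX(alpha).

    alpha_beta_gamma: three angles in radians (floats or tensors)
    device: torch device on which R is built (e.g. "cpu" or "cuda")
    R is built from differentiable torch operations, so gradients flow back
    to the angles when they are tensors with requires_grad=True.
    """
    alpha, beta, gamma = alpha_beta_gamma
    tensor_0 = torch.zeros(1, device=device)
    tensor_1 = torch.ones(1, device=device)
    # Turn each angle into a 1-element tensor on `device`; requires_grad=True
    # keeps R in the autograd graph even when plain floats are passed.
    alpha = torch.ones(1, requires_grad=True, device=device) * alpha
    beta = torch.ones(1, requires_grad=True, device=device) * beta
    gamma = torch.ones(1, requires_grad=True, device=device) * gamma
    # Rotation about the x axis by alpha. Stacking 1-element tensors gives
    # shape (3, 3, 1); reshape(3, 3) drops the trailing axis.
    RX = torch.stack([torch.stack([tensor_1, tensor_0, tensor_0]),
                      torch.stack([tensor_0, torch.cos(alpha), -torch.sin(alpha)]),
                      torch.stack([tensor_0, torch.sin(alpha), torch.cos(alpha)])]).reshape(3, 3)
    # Rotation about the y axis by beta
    RY = torch.stack([torch.stack([torch.cos(beta), tensor_0, torch.sin(beta)]),
                      torch.stack([tensor_0, tensor_1, tensor_0]),
                      torch.stack([-torch.sin(beta), tensor_0, torch.cos(beta)])]).reshape(3, 3)
    # Rotation about the z axis by gamma
    RZ = torch.stack([torch.stack([torch.cos(gamma), -torch.sin(gamma), tensor_0]),
                      torch.stack([torch.sin(gamma), torch.cos(gamma), tensor_0]),
                      torch.stack([tensor_0, tensor_0, tensor_1])]).reshape(3, 3)
    # Apply RX first, then RY, then RZ (rotations about the fixed x, y, z axes)
    R = RZ.mm(RY).mm(RX)
    return R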
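A possible usage sketch, assuming a CPU device, float32 angles, and plain SGD as a stand-in for whatever optimizer you actually use (the target angles, learning rate, and step count are arbitrary illustrative choices). The block repeats the gist's function so it runs on its own. It first checks that the output is a proper rotation (R Rᵀ = I, det R = 1), then recovers the angles of a target matrix by backpropagating through `build_rotation_matrix`.

```python
import torch


def build_rotation_matrix(alpha_beta_gamma, device):
    # Same construction as the gist: R = RZ(gamma) @ RY(beta) @ RX(alpha)
    alpha, beta, gamma = alpha_beta_gamma
    tensor_0 = torch.zeros(1, device=device)
    tensor_1 = torch.ones(1, device=device)
    alpha = torch.ones(1, requires_grad=True, device=device) * alpha
    beta = torch.ones(1, requires_grad=True, device=device) * beta
    gamma = torch.ones(1, requires_grad=True, device=device) * gamma
    RX = torch.stack([torch.stack([tensor_1, tensor_0, tensor_0]),
                      torch.stack([tensor_0, torch.cos(alpha), -torch.sin(alpha)]),
                      torch.stack([tensor_0, torch.sin(alpha), torch.cos(alpha)])]).reshape(3, 3)
    RY = torch.stack([torch.stack([torch.cos(beta), tensor_0, torch.sin(beta)]),
                      torch.stack([tensor_0, tensor_1, tensor_0]),
                      torch.stack([-torch.sin(beta), tensor_0, torch.cos(beta)])]).reshape(3, 3)
    RZ = torch.stack([torch.stack([torch.cos(gamma), -torch.sin(gamma), tensor_0]),
                      torch.stack([torch.sin(gamma), torch.cos(gamma), tensor_0]),
                      torch.stack([tensor_0, tensor_0, tensor_1])]).reshape(3, 3)
    return RZ.mm(RY).mm(RX)


# 1) The result is a proper rotation: orthonormal with determinant +1.
R = build_rotation_matrix([0.3, -1.2, 2.0], device="cpu")
print(torch.allclose(R @ R.T, torch.eye(3), atol=1e-5))  # expected: True
print(torch.det(R).item())                               # close to 1.0

# 2) Gradients flow back to the angles: fit angles to a target matrix.
target_angles = torch.tensor([0.5, 0.1, -0.7])  # arbitrary example angles
target = build_rotation_matrix(target_angles, device="cpu").detach()

angles = torch.zeros(3, requires_grad=True)     # start from the identity
optimizer = torch.optim.SGD([angles], lr=0.1)
for step in range(300):
    optimizer.zero_grad()
    # Squared Frobenius distance between the current and target matrices
    loss = ((build_rotation_matrix(angles, device="cpu") - target) ** 2).sum()
    loss.backward()
    optimizer.step()

print(loss.item(), angles.detach())
```

Euler-angle parameterizations like this one have a singularity at beta = ±π/2 (gimbal lock), where alpha and gamma stop being independent; close to it, gradient-based fitting of the angles can become ill-conditioned.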