Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Super-slow `torch.jit.script` for loop
def _reflector(v: torch.Tensor) -> torch.Tensor:
    """Return one Householder reflector per row of ``v``.

    For each row ``v_k`` the result holds ``I - 2 v_k v_k^T / (v_k . v_k)``,
    the reflection across the plane orthogonal to ``v_k``. An ``(m, d)``
    input yields an ``(m, d, d)`` output; a zero row yields NaNs.
    """
    outer = v.unsqueeze(2) * v.unsqueeze(1)                  # (m, d, d)
    sq_norm = (v * v).sum(dim=1, keepdim=True).unsqueeze(2)  # (m, 1, 1)
    eye = torch.eye(v.size(1), dtype=v.dtype, device=v.device)
    return eye - outer / sq_norm * 2.0


@torch.jit.script
def drloop(V: torch.Tensor, T: torch.Tensor) -> torch.Tensor:
    """Propagate a reference vector along a curve by double reflection.

    Step ``i`` reflects the running vector across the plane orthogonal to
    ``V[i]``, then across the plane orthogonal to ``T[i+1] - T[i]^L``, where
    ``T[i]^L`` is ``T[i]`` under the first reflection. When the rows of ``T``
    are unit length each step maps ``T[i]`` onto ``T[i+1]``, so a start
    vector orthogonal to ``T[0]`` stays orthogonal to every ``T[i]``.

    NOTE(review): the double-reflection paper uses the chord
    ``x[i+1] - x[i]`` as the first reflection normal; ``dr`` passes
    ``torch.gradient`` output instead. Confirm this variant is intended.

    Args:
        V: ``(n, 3)`` first-reflection normals, ``n >= 1``. A zero row
            produces NaNs.
        T: ``(n, 3)`` tangent vectors, same dtype/device as ``V``.

    Returns:
        ``(n, 3)`` tensor whose row ``i`` is the propagated vector at sample
        ``i``, in the dtype of ``T``. Reflections preserve length, so rows are
        not re-normalised here.
    """
    t0 = T[0]
    zero = torch.zeros_like(t0[0])
    # Start from the xy-components of T[0] rotated 90 degrees about z, which
    # is orthogonal to T[0]. That vector vanishes when T[0] is parallel to the
    # z-axis, so fall back to (0, -T0z, T0y), also orthogonal to T[0].
    r_xy = torch.stack([-t0[1], t0[0], zero])
    r_yz = torch.stack([zero, -t0[2], t0[1]])
    r = torch.where((t0[0] * t0[0] + t0[1] * t0[1]) > 0, r_xy, r_yz)

    # The reflections depend only on V and T, never on r, so every step's
    # matrix H2 @ H1 is built in one batched pass. Coefficients stay tensors:
    # turning them into Python floats would cut them out of the autograd
    # graph and force a device sync on every step.
    h1 = _reflector(V[:-1])                                         # (n-1, 3, 3)
    t_reflected = torch.matmul(h1, T[:-1].unsqueeze(2)).squeeze(2)  # (n-1, 3)
    steps = torch.matmul(_reflector(T[1:] - t_reflected), h1)       # (n-1, 3, 3)

    # Only the propagation itself is inherently sequential.
    frames = [r]
    for step in steps.unbind(0):
        r = torch.mv(step, r)
        frames.append(r)
    return torch.stack(frames)
def dr(points):
    """Compute a discrete rotation-minimizing frame along sampled curve points.

    Reference (double reflection method):
    https://www.microsoft.com/en-us/research/wp-content/uploads/2016/12/Computation-of-rotation-minimizing-frames.pdf

    Args:
        points: ``(n, 3)`` tensor of curve samples; ``torch.gradient`` needs
            at least two samples along dim 0.

    Returns:
        ``(T, S, R)``, each ``(n, 3)``: unit tangents, ``S = R x T``, and the
        unit normals propagated by ``drloop``.
        NOTE(review): the paper defines ``s = t x r``, i.e. ``-S`` here;
        confirm which sign callers expect.
    """
    # Central-difference derivative along the sample axis; used both as the
    # tangent direction and as drloop's first reflection normal.
    V = torch.gradient(points, dim=0)[0]
    T = V / torch.linalg.norm(V, dim=1, keepdim=True)
    R = drloop(V, T)
    R = R / torch.linalg.norm(R, dim=1, keepdim=True)
    # Explicit dim: without it torch.cross uses the FIRST size-3 dimension,
    # which is the point axis when there are exactly 3 points.
    S = torch.cross(R, T, dim=1)
    return T, S, R
X = torch.randn(1000, 3, requires_grad=True)
# NOTE: the first calls of a torch.jit.script function presumably include JIT
# profiling/optimisation, so this single-shot timing may overstate steady state.
t1 = time.perf_counter()
# dr returns (T, S, R); bound here as T, U, V.
T, U, V = dr(X)
t2 = time.perf_counter()
print(t2 - t1)
loss = (T + U + V).sum()
loss.backward()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment