Skip to content

Instantly share code, notes, and snippets.

@andreaskoepf
Created October 19, 2015 22:26
Show Gist options
  • Save andreaskoepf/9a0cd440b675be647d27 to your computer and use it in GitHub Desktop.
Save andreaskoepf/9a0cd440b675be647d27 to your computer and use it in GitHub Desktop.
Experimental Torch criterion measuring the rotation distance between quaternions.
-- RotationDistance: an nn.Criterion comparing predicted and target
-- quaternions (an angle term plus a unit-norm penalty on the prediction).
local RotationDistance, parent = torch.class('RotationDistance', 'nn.Criterion')

--- Construct the criterion.
-- @param weights accepted for interface compatibility; the current
--   implementation does not store or use it.
function RotationDistance.__init(self, weights)
   -- only the base nn.Criterion state is set up here
   parent.__init(self)
end
--- Forward pass: mean rotation distance over the batch.
-- For every row i the per-sample loss is
--   acos(|<x_i, t_i>| / (||x_i|| * ||t_i||)) + |1 - ||x_i|||
-- i.e. the angle between the two vectors, where the abs() makes x and -x
-- score identically, plus a penalty for x_i deviating from unit length.
-- All temporaries are derived from `input`/`target`, so the criterion works
-- with whatever tensor type it is given (CPU or CUDA).
-- @param input  predictions, one vector per row (presumably N x 4
--   quaternions -- TODO confirm with callers)
-- @param target targets with the same layout as `input`
-- @return scalar mean loss (also stored in self.output)
function RotationDistance:updateOutput(input, target)
   -- row-wise L2 norms, N x 1
   local input_norm = input:norm(2, 2)
   local target_norm = target:norm(2, 2)

   -- normalised inner product per row, N x 1
   local cos_angle = torch.cmul(input, target):sum(2)
      :cdiv(torch.cmul(input_norm, target_norm))

   -- abs() treats x and -x as the same; clamping to [0, 1] keeps acos
   -- defined when rounding pushes the ratio slightly above 1
   local angle = cos_angle:abs():clamp(0, 1):acos()

   -- |1 - ||x||| == |||x|| - 1|; computed on the N x 1 norm tensor so no
   -- separately allocated (and previously hard-coded :cuda()) ones tensor
   -- and no shape mismatch with `angle`
   local norm_penalty = input_norm:clone():add(-1):abs()

   self.output = (angle + norm_penalty):mean()
   return self.output
end
--- Backward pass (experimental, heuristic gradient).
-- Instead of the analytic derivative of updateOutput, each input row is
-- pulled towards a unit-length "target point" w_i:
--   k_i = <x_i, t_i> / (||x_i|| * ||t_i||)
--   v_i = -sign(k_i) / (1 - min(|k_i|, 0.99999)^2)
--   w_i = normalize(x_i - v_i * t_i)
--   gradInput_i = x_i - w_i
-- All temporaries are derived from `input`/`target`, so this works with
-- whatever tensor type it is given (CPU or CUDA).
-- NOTE(review): the result is not scaled by 1/N even though updateOutput
-- averages over the batch, and the |1 - ||x||| term has no separate gradient
-- term -- confirm this is intended.
-- @param input  predictions, same layout as in updateOutput
-- @param target targets with the same layout as `input`
-- @return self.gradInput, resized to the size of `input`
function RotationDistance:updateGradInput(input, target)
   -- per-row scaling factor ||x_i|| * ||t_i||, N x 1
   local scale = torch.cmul(input:norm(2, 2), target:norm(2, 2))
   -- scaled inner product, N x 1
   local k = torch.cmul(input, target):sum(2):cdiv(scale)
   -- cap |k| below 1 so the division below stays finite (avoids inf/NaN)
   local l = torch.abs(k):clamp(0, 0.99999)
   -- 1 - l^2, built from l so it has input's tensor type
   local denom = torch.pow(l, 2):mul(-1):add(1)
   -- The analytic derivative is d/dx acos(x) = -1 / sqrt(1 - x^2). This
   -- heuristic deliberately uses -sign(k) / (1 - l^2) (no square root),
   -- which weights nearly aligned rows much more strongly.
   local v = torch.sign(k):mul(-1):cdiv(denom)
   -- step from the input towards the sign-aligned target
   local w = input - torch.cmul(v:expandAs(target), target)
   -- project the target point onto the unit sphere (row-wise normalise)
   w:cdiv(w:norm(2, 2):expandAs(w))
   -- gradient = difference between the input and the projected target point
   self.gradInput:resizeAs(input):copy(input - w)
   return self.gradInput
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment