Batched L2 Normalization Layer for Torch nn package
--[[
This layer expects an [n x d] Tensor and normalizes each
row to have unit L2 norm.
]]--
local L2Normalize, parent = torch.class('nn.L2Normalize', 'nn.Module')

function L2Normalize:__init()
  parent.__init(self)
end

function L2Normalize:updateOutput(input)
  assert(input:dim() == 2, 'only mini-batch supported (2D tensor), got '
    .. input:dim() .. 'D tensor instead')
  self.output:resizeAs(input)
  -- lazily allocate work buffers of the same type as the input
  self.buffer = self.buffer or input.new()
  self.normSquared = self.normSquared or input.new()
  -- row-wise squared L2 norms, [n x 1]
  self.normSquared:sum(self.buffer:cmul(input, input), 2)
  -- row-wise L2 norms, [n x 1]
  self.buffer:sqrt(self.normSquared)
  -- divide each row by its norm
  self.output:copy(input):cdiv(self.buffer:expandAs(input))
  return self.output
end

function L2Normalize:updateGradInput(input, gradOutput)
  assert(input:dim() == 2, 'only mini-batch supported')
  assert(gradOutput:dim() == 2, 'only mini-batch supported')
  local n = input:size(1) -- batch size
  local d = input:size(2) -- dimensionality of vectors
  -- compute diagonal term: ||x||^2 * I for each example, [n x d x d]
  self.eye = self.eye or torch.eye(d):typeAs(input):repeatTensor(n, 1):view(n, d, d)
  self.diag = self.diag or self.eye.new()
  self.diag:cmul(self.eye, self.normSquared:view(n, 1, 1):expand(n, d, d))
  -- compute cross term: x x^T for each example, [n x d x d]
  local b1 = input:view(n, d, 1)
  local b2 = input:view(n, 1, d)
  self.diag:add(-torch.bmm(b1, b2))
  -- local Jacobian of the L2 transformation: (||x||^2 I - x x^T) / ||x||^3
  self.diag:cdiv(torch.pow(self.buffer, 3):view(n, 1, 1):expand(n, d, d))
  -- chain the gradient: batched matrix-vector product with gradOutput
  self.gradInput:resize(n, d, 1):bmm(self.diag, gradOutput:view(n, d, 1)):resize(n, d)
  return self.gradInput
end
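
A quick way to sanity-check the analytic gradient above is Torch's nn.Jacobian test utility, which compares updateGradInput against a finite-difference estimate. A minimal sketch (the batch size and dimensionality here are arbitrary test values):

require 'nn'
-- assumes nn.L2Normalize above has already been defined (e.g. via dofile)

local n, d = 4, 5                    -- small test batch
local module = nn.L2Normalize()
local input = torch.randn(n, d)

-- forward pass: every output row should have unit L2 norm
local output = module:forward(input)
print(output:norm(2, 2))             -- expect a column of ones

-- compare analytic vs. numerical Jacobian
local err = nn.Jacobian.testJacobian(module, input)
print('max Jacobian error: ' .. err) -- should be tiny, near machine precision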
@Atcold commented May 5, 2015

Any reason for not using torch.norm()?
You are doing this, right?

\[\frac{\partial\frac{\boldsymbol{x}}{\sqrt{\boldsymbol{x}^\top\boldsymbol{x}}}}{\partial\boldsymbol{x}} = \frac{\mathbb{I}}{\sqrt{\boldsymbol{x}^\top\boldsymbol{x}}} - \frac{\boldsymbol{x}\boldsymbol{x}^\top}{\sqrt{\left(\boldsymbol{x}^\top\boldsymbol{x}\right)^3}}\]
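
Rewriting the right-hand side over a common denominator makes the correspondence with the code explicit: the diag buffer holds the numerator before the final cdiv by the cubed norm:

\[\frac{\mathbb{I}}{\sqrt{\boldsymbol{x}^\top\boldsymbol{x}}} - \frac{\boldsymbol{x}\boldsymbol{x}^\top}{\sqrt{\left(\boldsymbol{x}^\top\boldsymbol{x}\right)^3}} = \frac{\left(\boldsymbol{x}^\top\boldsymbol{x}\right)\mathbb{I} - \boldsymbol{x}\boldsymbol{x}^\top}{\left(\boldsymbol{x}^\top\boldsymbol{x}\right)^{3/2}}\]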
@Atcold commented May 5, 2015

Shouldn't line 31 be followed by :squeeze() to match the dimensions? I get some funky gradInput dimensions otherwise.

@karpathy (Owner) commented May 5, 2015

Do you mean using norm() in the forward pass? That could be done.

Oops, you're right about squeeze, fixed!
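
For reference, a minimal sketch of what a norm()-based forward pass could look like (my own rewrite, not from the gist; note it drops the self.buffer and self.normSquared caches that updateGradInput above relies on, so those would still need to be populated for the backward pass):

function L2Normalize:updateOutput(input)
  assert(input:dim() == 2, 'only mini-batch supported')
  -- row-wise L2 norms via torch.norm along dimension 2: [n x 1]
  -- (allocates a fresh tensor each call; a cached buffer would avoid this)
  local norm = input:norm(2, 2)
  self.output:resizeAs(input):copy(input):cdiv(norm:expandAs(input))
  return self.output
end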

@Atcold commented May 6, 2015

I saw @soumith gave you other pointers as well.
It's nice to have you join the Torch circle 😉
