Batched L2 Normalization Layer for Torch nn package
--[[
This layer expects an [n x d] Tensor and normalizes each
row to have unit L2 norm.
]]--
local L2Normalize, parent = torch.class('nn.L2Normalize', 'nn.Module')

function L2Normalize:__init()
  parent.__init(self)
end

function L2Normalize:updateOutput(input)
  assert(input:dim() == 2, 'only mini-batch supported (2D tensor), got '
    .. input:dim() .. 'D tensor instead')
  self.output:resizeAs(input)
  self.buffer = self.buffer or input.new()
  self.normSquared = self.normSquared or input.new()
  -- per-row squared L2 norms, [n x 1]
  self.normSquared:sum(self.buffer:cmul(input, input), 2)
  -- per-row L2 norms, [n x 1]
  self.buffer:sqrt(self.normSquared)
  -- divide each row by its norm
  self.output:copy(input):cdiv(self.buffer:expandAs(input))
  return self.output
end

function L2Normalize:updateGradInput(input, gradOutput)
  assert(input:dim() == 2, 'only mini-batch supported')
  assert(gradOutput:dim() == 2, 'only mini-batch supported')
  local n = input:size(1) -- batch size
  local d = input:size(2) -- dimensionality of vectors
  -- compute diagonal term; rebuild the cached identity if the batch size changed
  if not self.eye or self.eye:size(1) ~= n then
    self.eye = torch.eye(d):typeAs(input):repeatTensor(n,1):view(n,d,d)
  end
  self.diag = self.diag or self.eye.new()
  self.diag:cmul(self.eye, self.normSquared:view(n,1,1):expand(n,d,d))
  -- compute cross term: the outer product x*x^T for each row
  local b1 = input:view(n,d,1)
  local b2 = input:view(n,1,d)
  self.diag:add(-torch.bmm(b1,b2))
  -- local gradient (Jacobian) of the normalization: (||x||^2*I - x*x^T) / ||x||^3
  self.diag:cdiv(torch.pow(self.buffer,3):view(n,1,1):expand(n,d,d))
  -- chain the gradient through the (symmetric) Jacobian
  self.gradInput:resize(n,d,1):bmm(self.diag, gradOutput:view(n,d,1)):resize(n,d)
  return self.gradInput
end
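
A minimal usage sketch (assuming the class above is saved as L2Normalize.lua next to the script; the filename is arbitrary):

require 'torch'
require 'nn'
dofile('L2Normalize.lua')  -- hypothetical path to the layer defined above

local m = nn.L2Normalize()
local x = torch.randn(4, 3)  -- mini-batch of 4 three-dimensional vectors
local y = m:forward(x)
print(y:norm(2, 2))          -- every row should now have unit L2 norm
local gradInput = m:backward(x, torch.randn(4, 3))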
@Atcold commented May 5, 2015

Any reason for not using torch.norm()?
You are doing this, right?

\[\frac{\partial\frac{\boldsymbol{x}}{\sqrt{\boldsymbol{x}^\top\boldsymbol{x}}}}{\partial\boldsymbol{x}} = \frac{\mathbb{I}}{\sqrt{\boldsymbol{x}^\top\boldsymbol{x}}} - \frac{\boldsymbol{x}\boldsymbol{x}^\top}{\sqrt{\left(\boldsymbol{x}^\top\boldsymbol{x}\right)^3}}\]
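One quick way to check that updateGradInput matches this Jacobian is a numerical gradient check, e.g. with nn.Jacobian (the helper nn's own unit tests use); a rough sketch, assuming the layer is saved as L2Normalize.lua and using an arbitrary tolerance:

require 'nn'
dofile('L2Normalize.lua')  -- hypothetical path to the gist above

local module = nn.L2Normalize()
local input = torch.randn(8, 5)
-- compares the analytic Jacobian (via updateGradInput) against finite differences
local err = nn.Jacobian.testJacobian(module, input)
print('max Jacobian error: ' .. err)
assert(err < 1e-5, 'gradient check failed')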
@Atcold commented May 5, 2015

Shouldn't line 31 be followed by :squeeze() to match the dimensions? I get some funky gradInput dimensions otherwise.

@karpathy (owner) commented May 5, 2015

Do you mean using norm() in the forward pass? That could be done.

Oops, you're right about squeeze, fixed!
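
For reference, a sketch of what the forward pass could look like with a norm() call instead of the cmul/sum/sqrt sequence (assuming Torch's tensor:norm(p, dim) overload; not karpathy's actual implementation):

function L2Normalize:updateOutput(input)
  assert(input:dim() == 2, 'only mini-batch supported')
  -- per-row L2 norms, [n x 1]
  self.norm = input:norm(2, 2)
  self.output:resizeAs(input):copy(input):cdiv(self.norm:expandAs(input))
  return self.output
end

(Note this allocates a fresh tensor per call, unlike the buffered version above, and updateGradInput would then need self.norm in place of self.buffer/self.normSquared.)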

@Atcold commented May 6, 2015

I saw @soumith gave you some other pointers as well.
It's nice to have you join the Torch circle 😉
