Skip to content

Instantly share code, notes, and snippets.

@nagadomi
Created October 21, 2018 00:50
Show Gist options
  • Save nagadomi/fcbb5cd7e39079d34bd2aa62fff665c9 to your computer and use it in GitHub Desktop.
Torch7 SEBlock (Squeeze and Excitation Networks)
require 'nn'
-- nn.ScaleTable: given a table input {features, scale}, multiplies the
-- features element-wise by the scale tensor expanded to the features' size
-- (the "excitation" step of a squeeze-and-excitation block).
local ScaleTable, parent = torch.class("nn.ScaleTable", "nn.Module")

--- Initialise the module and the buffers reused across forward/backward.
function ScaleTable:__init()
parent.__init(self)
-- holds input[2] expanded to input[1]'s size (set in updateOutput)
self.scale = torch.Tensor()
-- scratch buffer for the backward pass
self.grad_tmp = torch.Tensor()
-- table of gradients, one entry per table input
self.gradInput = {}
end
-- Forward pass.
-- input is a table of two tensors:
--   input[1]: feature map (the reshape in updateGradInput assumes 4D:
--             batch x channels x H x W)
--   input[2]: per-channel scale with the same channel count (dim 2),
--             e.g. batch x channels x 1 x 1 from global average pooling
-- Returns output = input[1] .* expand(input[2]) (element-wise product).
function ScaleTable:updateOutput(input)
assert(#input == 2)
assert(input[1]:size(2) == input[2]:size(2))
-- Three-arg expandAs form (result, tensor, template): turns self.scale into
-- a zero-stride view of input[2] expanded to input[1]'s size. The preceding
-- resizeAs result is then overwritten by that view.
self.scale:resizeAs(input[1]):expandAs(input[2], input[1])
-- Materialize the expanded scale, then multiply by the features in place.
self.output:resizeAs(self.scale):copy(self.scale)
self.output:cmul(input[1])
return self.output
end
-- Backward pass. Relies on self.scale computed by the preceding
-- updateOutput call (standard nn.Module contract: forward before backward).
-- gradInput[1] = gradOutput .* scale        (d/d features)
-- gradInput[2] = sum over H,W of gradOutput .* input[1]   (d/d scale)
function ScaleTable:updateGradInput(input, gradOutput)
self.gradInput[1] = self.gradInput[1] or input[1].new()
self.gradInput[1]:resizeAs(input[1]):copy(gradOutput)
self.gradInput[1]:cmul(self.scale)
-- grad_tmp = gradOutput .* input[1], reduced below to per-channel gradients
self.grad_tmp:resizeAs(input[1]):copy(gradOutput)
self.grad_tmp:cmul(input[1])
self.gradInput[2] = self.gradInput[2] or input[2].new()
-- NOTE(review): hard-codes a 4D (batch, channel, H, W) layout via
-- size(1)..size(4); the reshape flattens H*W so sum(..., 3) reduces all
-- spatial positions, and the final resizeAs restores input[2]'s shape.
self.gradInput[2]:resizeAs(input[2]):sum(self.grad_tmp:reshape(self.grad_tmp:size(1), self.grad_tmp:size(2), self.grad_tmp:size(3) * self.grad_tmp:size(4)), 3):resizeAs(input[2])
-- Drop stale gradient entries left over from a previous call with a
-- larger input table.
for i=#input+1, #self.gradInput do
self.gradInput[i] = nil
end
return self.gradInput
end
-- usage
-- Example: a small residual CNN using nn.ScaleTable for SE blocks.
local function example_net()
   -- Residual block: two unpadded 3x3 convolutions (optionally followed by a
   -- squeeze-and-excitation branch), added to a shortcut that is cropped by
   -- 2px per side to match the spatial shrinkage of the unpadded convs.
   local function resblock(nin, nout, with_se)
      local outer = nn.Sequential()
      local branches = nn.ConcatTable()
      local body = nn.Sequential()
      body:add(nn.SpatialConvolution(nin, nout, 3, 3, 1, 1, 0, 0))
      body:add(nn.LeakyReLU(0.1, true))
      body:add(nn.SpatialConvolution(nout, nout, 3, 3, 1, 1, 0, 0))
      body:add(nn.LeakyReLU(0.1, true))
      if with_se then
         -- squeeze: global pooling -> bottleneck -> per-channel sigmoid gates
         local reduction = 4
         local hidden = math.floor(nout / reduction)
         local squeeze = nn.Sequential()
         squeeze:add(nn.SpatialAdaptiveAveragePooling(1, 1)) -- global average pooling
         squeeze:add(nn.SpatialConvolution(nout, hidden, 1, 1, 1, 1, 0, 0))
         squeeze:add(nn.ReLU(true))
         squeeze:add(nn.SpatialConvolution(hidden, nout, 1, 1, 1, 1, 0, 0))
         squeeze:add(nn.Sigmoid(true))
         -- excite: {features, gates} -> features .* gates
         local se_split = nn.ConcatTable()
         se_split:add(nn.Identity())
         se_split:add(squeeze)
         body:add(se_split)
         body:add(nn.ScaleTable())
      end
      branches:add(body)
      if nin == nout then
         branches:add(nn.SpatialZeroPadding(-2, -2, -2, -2)) -- identity + de-padding
      else
         -- 1x1 projection shortcut when the channel count changes
         local proj = nn.Sequential()
         proj:add(nn.SpatialConvolution(nin, nout, 1, 1, 1, 1, 0, 0))
         proj:add(nn.SpatialZeroPadding(-2, -2, -2, -2))
         branches:add(proj)
      end
      outer:add(branches)
      outer:add(nn.CAddTable())
      return outer
   end

   local channels = 3
   local net = nn.Sequential()
   net:add(nn.SpatialConvolution(channels, 32, 3, 3, 1, 1, 0, 0))
   net:add(nn.LeakyReLU(0.1, true))
   net:add(resblock(32, 64, true))
   net:add(resblock(64, 64, true))
   net:add(resblock(64, 64, true))
   net:add(resblock(64, 128, true))
   net:add(nn.SpatialFullConvolution(128, channels, 4, 4, 2, 2, 3, 3):noBias())
   -- run
   print(net)
   print(net:forward(torch.Tensor(4, 3, 32, 32):uniform()):size())
end
--example_net()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment