Skip to content

Instantly share code, notes, and snippets.

@lebedov
Created September 28, 2016 14:01
Show Gist options
  • Save lebedov/5160317776ba808e163aebd2884c2913 to your computer and use it in GitHub Desktop.
Recursively show layer output tensor sizes in a Torch model.
-- Recursively show layer output tensor sizes in a Torch model.
require 'nn'
require 'torch'
--- Concatenate the elements of a sequence into one string.
-- @tparam table list sequence of strings/numbers to concatenate
-- @tparam[opt=' '] string sep text inserted between consecutive elements
-- @treturn string the joined string ('' for an empty list)
function join(list, sep)
  return table.concat(list, sep or ' ')
end
--- Run a forward pass and print every layer's output size, recursing into
-- nested containers.
-- Each printed line has the form `<dashes> (<index>) <size>`:
-- * the number of dashes is the nesting depth (1 for direct children of `m`);
-- * `<index>` is the layer's 1-based position within its parent container;
-- * `<size>` is the output tensor's dimensions joined with 'x'.
-- Layers whose output is a table print one size per element, each followed by
-- a space. Tables nested inside such outputs are shown in braces.
-- @param input value passed to `forward`
-- @param m module to inspect; its `modules` list (if any) is walked
function show_layer_sizes(input, m)
  -- Run the forward pass on a clone of `m` so every layer's `output` field is
  -- populated on the clone.
  local model = m:clone()
  model:forward(input)

  -- Size string for one output: a tensor becomes e.g. '10x1x28x28'.
  -- A (nested) table of tensors is rendered recursively inside braces.
  local function size_string(out)
    if torch.type(out) == 'table' then
      local parts = {}
      for i, item in ipairs(out) do
        parts[i] = size_string(item)
      end
      return '{' .. join(parts, ' ') .. '}'
    end
    return join(torch.totable(out:size()), 'x')
  end

  -- Depth-first walk over the container tree. A local function avoids
  -- leaking a global `rec`. `ipairs` visits (and numbers) children in
  -- sequence order, which `pairs` does not guarantee.
  local function rec(container, depth)
    for index, child in ipairs(container.modules or {}) do
      local out = child.output
      if out ~= nil then
        local result = string.rep('-', depth) .. string.format(' (%s) ', index)
        if torch.type(out) == 'table' then
          -- Same layout as before for flat tables: each size then a space.
          for _, item in ipairs(out) do
            result = result .. size_string(item) .. ' '
          end
        else
          result = result .. size_string(out)
        end
        print(result)
      end
      -- Recurse into containers (anything with its own `modules` list).
      if child.modules ~= nil then
        rec(child, depth + 1)
      end
    end
  end

  rec(model, 1)
end
-- Demo using a model with branches:
do
  -- LeNet-style encoder for 1x28x28 inputs (5x5 convolutions, 2x2 pooling):
  -- 28 -conv-> 24 -pool-> 12 -conv-> 8 -pool-> 4, hence 50*4*4 features.
  -- All demo variables are local so the `do` block does not leak globals.
  local encoder = nn.Sequential()
  encoder:add(nn.SpatialConvolutionMM(1, 20, 5, 5))
  encoder:add(nn.SpatialMaxPooling(2, 2, 2, 2))
  encoder:add(nn.SpatialConvolutionMM(20, 50, 5, 5))
  encoder:add(nn.SpatialMaxPooling(2, 2, 2, 2))
  encoder:add(nn.View(50*4*4))
  encoder:add(nn.Linear(50*4*4, 500))
  encoder:add(nn.ReLU())
  encoder:add(nn.Linear(500, 10))
  encoder:add(nn.Linear(10, 2))

  -- Two branches; the second is a clone sharing weight, bias and their
  -- gradients with the first.
  local siamese_encoder = nn.ParallelTable()
  siamese_encoder:add(encoder)
  siamese_encoder:add(encoder:clone('weight', 'bias', 'gradWeight', 'gradBias'))

  -- SplitTable(2) splits the input along dimension 2 into a table of
  -- tensors, one per branch of the ParallelTable.
  local model = nn.Sequential()
  model:add(nn.SplitTable(2))
  model:add(siamese_encoder)
  model:add(nn.PairwiseDistance(2))

  local x = torch.randn(10, 2, 1, 28, 28)
  show_layer_sizes(x, model)
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment