Navigation Menu

Skip to content

Instantly share code, notes, and snippets.

@taey16
Created October 25, 2015 16:23
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save taey16/a8f6e7bf36710a4e2293 to your computer and use it in GitHub Desktop.
fcn
require 'nn'
--- Convert an `nn.Linear` layer into an equivalent `nn.SpatialConvolutionMM`.
-- You just need to provide the linear module you want to convert and the
-- spatial dimensions of the linear layer's field of view. The linear
-- weights are copied element-for-element into a convolution whose kernel
-- covers that whole field of view (stride 1, no padding).
--
-- On an input of exactly that spatial size, the convolution output is
-- spatially 1x1. Despite the function name, the kernel itself is
-- in_size[1] x in_size[2], not 1x1.
--
-- @param linmodule the `nn.Linear` module to convert; its `bias` is copied
--   unconditionally, so it must have one
-- @param in_size two-element table `{kW, kH}`, i.e. field-of-view width then
--   height, in the same order nn.SpatialConvolutionMM takes them. This only
--   matters for non-square inputs; a torch (C, H, W) tensor has size
--   {H, W}, so pass its sizes reversed.
-- @return a new `nn.SpatialConvolutionMM` initialised from linmodule
function convertLinear2Conv1x1(linmodule, in_size)
  assert(type(in_size) == "table" and in_size[1] ~= nil and in_size[2] ~= nil,
    "in_size must be a table {kW, kH}")
  local k_w, k_h = in_size[1], in_size[2]
  local n_in_features = linmodule.weight:size(2)
  local s_out = linmodule.weight:size(1)
  -- The flattened linear input must split into whole planes of k_w*k_h
  -- pixels; otherwise no (planes x k_h x k_w) input could have produced it
  -- and the plane count below would be fractional.
  local s_in = n_in_features / (k_w * k_h)
  assert(s_in >= 1 and s_in % 1 == 0, string.format(
    "linear input size %s is not a multiple of %s x %s",
    tostring(n_in_features), tostring(k_w), tostring(k_h)))
  local convmodule = nn.SpatialConvolutionMM(s_in, s_out, k_w, k_h, 1, 1)
  -- copy() matches by element count, so the linear weight matrix maps onto
  -- the conv weight in its storage order.
  convmodule.weight:copy(linmodule.weight)
  convmodule.bias:copy(linmodule.bias)
  return convmodule
end
-- Demo: a Linear layer applied to a flattened 3x6x6 input and its
-- convolutional conversion should produce the same 10 outputs. The 6x6
-- kernel on a 6x6 input with stride 1 gives a spatially 1x1 conv output.
input = torch.rand(3,6,6)
m = nn.Linear(3*6*6,10)
mm = convertLinear2Conv1x1(m,{6,6})
output_lin = m:forward(input:view(3*6*6))
output_conv = mm:forward(input)
-- Actually check the equivalence the demo is meant to show; reshape the
-- conv output to the linear output's shape before comparing.
local max_err = (output_lin - output_conv:viewAs(output_lin)):abs():max()
assert(max_err < 1e-6,
  string.format("conversion mismatch: max abs error %g", max_err))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment