-- Gist farrajota/309f39776d4509cc6c00 by @farrajota (last active July 26, 2017 10:33)
--
-- Fully connected to convolution layer
require 'nn'
-- Provide the nn.Linear module to convert and the spatial size {width, height}
-- of its field of view, i.e. of the input feature map it was applied to.
function convertLinear2Conv1x1(linmodule,in_size)
--[[
Convert an nn.Linear module into an equivalent nn.SpatialConvolution module.

The linear layer is assumed to have been applied to a C x H x W feature map
flattened with input:view(C*H*W). The returned convolution's kernel covers the
whole field of view (kW x kH = W x H) and its stride is 1. On a C x H x W input,
its output therefore holds the same values as the linear layer, shaped
nOut x 1 x 1.

Arguments
linmodule - the nn.Linear module to convert. Its weight (nOut x C*H*W) is
            copied into the convolution kernel. Its bias, if present, is
            copied into the convolution bias; otherwise the bias is zeroed.
in_size   - 2-element table {width, height}: the spatial size of the linear
            layer's field of view. It becomes the kernel size
            (kW = in_size[1], kH = in_size[2]). It is NOT a stride; the
            stride is always 1.
Return values
convmodule - output convolution module (nn.SpatialConvolution())
Errors
Raises an error if in_size is missing or holds non-positive or non-integer
values, or if the linear input size is not a multiple of width*height.
Example:
input = torch.rand(3,6,6)
m = nn.Linear(3*6*6,10)
mm = convertLinear2Conv1x1(m,{6,6})
output_lin = m:forward(input:view(3*6*6))
output_conv = mm:forward(input)
--]]
  if in_size == nil then
    error('in_size must be a {width, height} table', 2)
  end
  local kW, kH = tonumber(in_size[1]), tonumber(in_size[2])
  if not kW or not kH or kW < 1 or kH < 1 or kW % 1 ~= 0 or kH % 1 ~= 0 then
    error('in_size must be a {width, height} table of positive integers', 2)
  end

  local n_in = linmodule.weight:size(2)
  local s_out = linmodule.weight:size(1)
  local area = kW * kH
  -- Every input plane owns exactly kW*kH weights per output unit, so the
  -- linear input size has to split evenly into planes.
  if n_in % area ~= 0 then
    error(string.format(
      'linear input size %d is not a multiple of width*height = %d*%d',
      n_in, kW, kH), 2)
  end
  local s_in = math.floor(n_in / area)

  local convmodule = nn.SpatialConvolution(s_in, s_out, kW, kH, 1, 1)
  -- Tensor:copy only requires equal element counts. The row-major
  -- (nOut x C*H*W) linear weight lines up element-for-element with the
  -- (nOut x C x kH x kW) kernel when kH = H and kW = W.
  convmodule.weight:copy(linmodule.weight)
  if linmodule.bias then
    convmodule.bias:copy(linmodule.bias)
  elseif convmodule.bias then
    -- Linear layer built without a bias: make the convolution bias a no-op.
    convmodule.bias:zero()
  end
  return convmodule
end
function convertLinear2Conv1x1_v2(linmodule,in_size, modulepointer)
--[[
Convert an nn.Linear module into an equivalent convolution module. The
convolution module type can be chosen by the caller.

The linear layer is assumed to have been applied to a C x H x W feature map
flattened with input:view(C*H*W). The returned convolution's kernel covers the
whole field of view (kW x kH = W x H) and its stride is 1. On a C x H x W input,
its output therefore holds the same values as the linear layer, shaped
nOut x 1 x 1.

Arguments
linmodule     - the nn.Linear module to convert. Its weight (nOut x C*H*W) is
                copied into the convolution kernel. Its bias, if present, is
                copied into the convolution bias; otherwise the convolution
                bias (if it has one) is zeroed.
in_size       - 2-element table {width, height}: the spatial size of the
                linear layer's field of view. It becomes the kernel size
                (kW = in_size[1], kH = in_size[2]). It is NOT a stride; the
                stride is always 1.
modulepointer - constructor of the convolution module to build (e.g.
                nn.SpatialConvolution, cudnn.SpatialConvolution). It is called
                as modulepointer(nInputPlane, nOutputPlane, kW, kH, 1, 1).
                Defaults to nn.SpatialConvolution when nil.
Return values
convmodule - output convolution module
Errors
Raises an error if in_size is missing or holds non-positive or non-integer
values, or if the linear input size is not a multiple of width*height.
Example:
input = torch.rand(3,6,6)
m = nn.Linear(3*6*6,10)
mm = convertLinear2Conv1x1_v2(m,{6,6}, nn.SpatialConvolution)
output_lin = m:forward(input:view(3*6*6))
output_conv = mm:forward(input)
--]]
  if in_size == nil then
    error('in_size must be a {width, height} table', 2)
  end
  local kW, kH = tonumber(in_size[1]), tonumber(in_size[2])
  if not kW or not kH or kW < 1 or kH < 1 or kW % 1 ~= 0 or kH % 1 ~= 0 then
    error('in_size must be a {width, height} table of positive integers', 2)
  end

  local n_in = linmodule.weight:size(2)
  local s_out = linmodule.weight:size(1)
  local area = kW * kH
  -- Every input plane owns exactly kW*kH weights per output unit, so the
  -- linear input size has to split evenly into planes.
  if n_in % area ~= 0 then
    error(string.format(
      'linear input size %d is not a multiple of width*height = %d*%d',
      n_in, kW, kH), 2)
  end
  local s_in = math.floor(n_in / area)

  local moduletype = modulepointer or nn.SpatialConvolution
  local convmodule = moduletype(s_in, s_out, kW, kH, 1, 1)
  -- Tensor:copy only requires equal element counts. The row-major
  -- (nOut x C*H*W) linear weight lines up element-for-element with the
  -- (nOut x C x kH x kW) kernel when kH = H and kW = W.
  convmodule.weight:copy(linmodule.weight)
  if linmodule.bias then
    convmodule.bias:copy(linmodule.bias)
  elseif convmodule.bias then
    -- Linear layer built without a bias: make the convolution bias a no-op.
    convmodule.bias:zero()
  end
  return convmodule
end
-- (GitHub page footer) Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment