Skip to content

Instantly share code, notes, and snippets.

@shrubb
Created August 4, 2017 21:00
Show Gist options
  • Save shrubb/fcba898593ad72014b579dc86806cac6 to your computer and use it in GitHub Desktop.
Save shrubb/fcba898593ad72014b579dc86806cac6 to your computer and use it in GitHub Desktop.
A rough benchmark script for depthwise convolution
require 'cunn'
torch.setdefaulttensortype('torch.FloatTensor')
-- Fixed seeds so repeated runs benchmark identical input data.
torch.manualSeed(666)
cutorch.manualSeed(666)

-- Benchmark configuration: change these as needed.
-- Nothing here checks numerical correctness; the script only measures time.
local batchSize = 160
local nInputPlane, nOutputPlane, inputHeight, inputWidth = 128, 1, 52, 52
-- torch.rand samples a FloatTensor (the default type set above) on the CPU;
-- :cuda() copies it to the GPU.
local input = torch.rand(batchSize, nInputPlane, inputHeight, inputWidth):cuda()

-- Select exactly one module to benchmark by (un)commenting the lines below.
-- `module` is declared local up front so that every alternative assigns to
-- this local instead of silently creating a global.
local module
-- *** depthwise ***
-- cudnn alternative; the trailing nInputPlane is presumably the `groups`
-- argument (one group per input plane) -- confirm against the cudnn.torch docs.
--require 'cudnn'; module = cudnn.SpatialConvolution(nInputPlane, nInputPlane, 3,3, 1,1, 1,1, nInputPlane)
-- NOTE(review): nOutputPlane = 1 presumably means one output plane per input
-- plane, which would make this equivalent to the cudnn line above -- confirm.
module = nn.SpatialDepthWiseConvolution(nInputPlane, nOutputPlane, 3,3, 1,1, 1,1)
-- *** full ***
--require 'cudnn'; module = cudnn.SpatialConvolution(nInputPlane, nInputPlane, 3,3, 1,1, 1,1)
--module = nn.SpatialConvolution(nInputPlane, nInputPlane, 3,3, 1,1, 1,1)
module:cuda()

-- Timing layout: nRepeats timed samples, each covering nRepeatsInner calls.
-- Several samples give the GPU time to settle and give a spread estimate;
-- the first sample is discarded by the reporting code.
local nRepeats = 10
local nRepeatsInner = 4
local times = torch.Tensor(nRepeats) -- one wall-clock time per sample

-- Untimed warm-up pass; it also populates module.output, whose size is later
-- used to build gradOutput for the backward benchmark.
module:forward(input)
-- The first sample is dropped from the statistics, and std needs at least two
-- remaining samples to be meaningful.
assert(nRepeats >= 3, 'nRepeats must be >= 3 (first sample is discarded)')

--- Time `step` and print "<label> average: <mean> +/- <2.5 * std>".
-- Runs nRepeats samples. Each sample calls `step` nRepeatsInner times, so
-- every entry of `times` (and hence the reported mean) is the time of
-- nRepeatsInner calls, not of a single call. cutorch.synchronize() brackets
-- each sample so the timer covers completed GPU work rather than only the
-- kernel launches. Results are written into the shared `times` tensor.
-- @tparam string label prefix of the summary line ('Forward' / 'Backward')
-- @tparam function step zero-argument function performing one call
local function benchmark(label, step)
   for k = 1, nRepeats do
      io.stdout:write(k .. ', '):flush() -- progress indicator
      collectgarbage() -- keep Lua GC pauses out of the timed region
      cutorch.synchronize()
      local timer = torch.Timer()
      for _ = 1, nRepeatsInner do
         step()
      end
      cutorch.synchronize()
      times[k] = timer:time().real
   end
   -- Terminate the progress line so the summary starts on its own line.
   io.stdout:write('\n')
   --print(times)
   -- Skip sample 1 as an additional warm-up.
   local steady = times[{{2, -1}}]
   print(label .. ' average: ' .. steady:mean() .. ' +/- ' .. steady:std() * 2.5)
end

benchmark('Forward', function() module:forward(input) end)

-- torch.rand allocates in the default tensor type (FloatTensor), but the
-- module runs on the GPU; convert so backward receives a tensor of the same
-- type as module.output and input.
local gradOutput = torch.rand(module.output:size()):typeAs(module.output)
module:backward(input, gradOutput) -- untimed warm-up
benchmark('Backward', function()
   --module:updateGradInput(input, gradOutput)
   --module:accGradParameters(input, gradOutput)
   module:backward(input, gradOutput)
end)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment