Skip to content

Instantly share code, notes, and snippets.

@jcjohnson
Created September 3, 2015 19:28
Show Gist options
  • Save jcjohnson/55e46dd4d33a874aa209 to your computer and use it in GitHub Desktop.
Save jcjohnson/55e46dd4d33a874aa209 to your computer and use it in GitHub Desktop.
require 'torch'
require 'cutorch'
require 'nn'
require 'cunn'
require 'cudnn'
-- Problem size: a batch of N inputs with `cin` channels of size
-- height x width, convolved into `cout` output channels.
local N = 32
local cin, cout = 64, 64
local height, width = 256, 256
local k = 3

-- Every layer under test is constructed with the same argument list
-- (cin, cout, k, k, 1, 1, 1, 1); only the implementation class and the
-- tensor type it is converted to differ.
local function build_conv(layer)
  return layer(cin, cout, k, k, 1, 1, 1, 1)
end

local nn_mod = build_conv(nn.SpatialConvolution):float()
local cunn_mod = build_conv(nn.SpatialConvolution):cuda()
local mm_mod = build_conv(nn.SpatialConvolutionMM):cuda()
local cudnn_mod = build_conv(cudnn.SpatialConvolution):cuda()

-- Random input and upstream gradient, first as float tensors and then as
-- CUDA copies for the GPU modules.
local input = torch.randn(N, cin, height, width):float()
local dout = torch.randn(N, cout, height, width):float()
local input_cuda = input:cuda()
local dout_cuda = dout:cuda()

-- Shared timer and trial count used by time_module.
local timer = torch.Timer()
local num_trials = 10
--- Print summary statistics for one phase of a benchmark.
-- @param name   label identifying the module being timed
-- @param phase  phase label inserted into the message ('forward'/'backward')
-- @param times  1-D tensor of timings (the `real` field of timer:time())
local function report_times(name, phase, times)
  local msg = '%s %s took %f +- %f [%f, %f]'
  print(string.format(msg, name, phase,
                      times:mean(), times:std(), times:min(), times:max()))
end

--- Time the forward and backward passes of a module and print statistics.
-- Runs `num_trials` forward/backward iterations. The first iteration is
-- treated as a warm-up and is not recorded, so statistics are computed over
-- the remaining `num_trials - 1` iterations.
-- @param mod   module exposing :forward(x) and :backward(x, dy)
-- @param x     input passed to forward and backward
-- @param dy    gradient with respect to the module output
-- @param name  label used in the printed report
local function time_module(mod, x, dy, name)
  local num_timed = num_trials - 1
  assert(num_timed >= 1,
         'num_trials must be at least 2 (the first trial is a warm-up)')

  -- Sized to the number of *recorded* trials. torch.DoubleTensor(n) does not
  -- initialize its contents, so a slot reserved for the skipped warm-up trial
  -- would feed garbage into mean/std/min/max.
  local forward_times = torch.DoubleTensor(num_timed)
  local backward_times = torch.DoubleTensor(num_timed)

  for t = 1, num_trials do
    -- Synchronize before starting the clock and before reading it so that
    -- the measured interval covers the whole pass.
    cutorch.synchronize()
    timer:reset()
    mod:forward(x)
    cutorch.synchronize()
    if t > 1 then forward_times[t - 1] = timer:time().real end

    timer:reset()
    mod:backward(x, dy)
    cutorch.synchronize()
    if t > 1 then backward_times[t - 1] = timer:time().real end
  end

  report_times(name, 'forward', forward_times)
  report_times(name, 'backward', backward_times)
end
-- The CPU run is left disabled:
-- time_module(nn_mod, input, dout, 'CPU')

-- GPU implementations to benchmark, as {module, label} pairs, run in order
-- on the same CUDA input and output gradient.
local gpu_runs = {
  { cunn_mod, 'cunn' },
  { mm_mod, 'MM' },
  { cudnn_mod, 'cudnn' },
}
for _, run in ipairs(gpu_runs) do
  time_module(run[1], input_cuda, dout_cuda, run[2])
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment