Skip to content

Instantly share code, notes, and snippets.

@jcjohnson
Last active April 20, 2016 11:09
Show Gist options
  • Save jcjohnson/61d23297d6ee67b065e5 to your computer and use it in GitHub Desktop.
Save jcjohnson/61d23297d6ee67b065e5 to your computer and use it in GitHub Desktop.
Simple torch benchmarking tool for fully-connected networks
require 'nn'
require 'cutorch'
require 'cunn'
--[[
-- A simple benchmark comparing fully-connected net times on CPU and GPU.
--
-- Note that we don't count the time it takes to transfer data to the GPU.
--]]
-- Command-line options. Every option carries a help string so that running
-- the script with -help describes what each knob controls.
local cmd = torch.CmdLine()
cmd:option('-input_dim', 100, 'Dimension of each input vector')
cmd:option('-output_dim', 4, 'Dimension of each output / target vector')
cmd:option('-hidden_dim', 4096, 'Width of every hidden layer')
cmd:option('-hidden_layers', 5, 'Number of hidden_dim x hidden_dim Linear + ReLU pairs')
cmd:option('-batch_size', 1000, 'Number of examples per forward/backward pass')
cmd:option('-num_trials', 5, 'Number of timed passes per tensor type')
cmd:option('-quiet', false, 'Do not print the time of each individual trial')
cmd:option('-gpu', 0, 'Zero-indexed id of the GPU to benchmark')
local opt = cmd:parse(arg)

-- The reported mean is taken over num_trials timings, so at least one trial
-- is required for the result to be meaningful.
assert(opt.num_trials >= 1, '-num_trials must be at least 1')

-- -gpu is zero-indexed on the command line; cutorch device ids start at 1.
cutorch.setDevice(opt.gpu + 1)
-- Assemble the fully-connected network: an input projection, followed by
-- opt.hidden_layers repetitions of (Linear, in-place ReLU), followed by the
-- output projection. It is paired with a mean-squared-error criterion.
-- NOTE(review): there is no nonlinearity between the input projection and the
-- first hidden Linear; presumably irrelevant for timing purposes -- confirm.
local in_dim, width, out_dim = opt.input_dim, opt.hidden_dim, opt.output_dim
local model = nn.Sequential()
model:add(nn.Linear(in_dim, width))
for _ = 1, opt.hidden_layers do
  local hidden = nn.Linear(width, width)
  local relu = nn.ReLU(true)
  model:add(hidden)
  model:add(relu)
end
model:add(nn.Linear(width, out_dim))
local crit = nn.MSECriterion()
-- Time forward + backward passes of the model for each tensor type.
-- The first entry (CPU floats) is the baseline; the second (CUDA) is compared
-- against it in the final speedup line.
local dtypes = {'torch.FloatTensor', 'torch.CudaTensor'}
local mean_times = {}
local timer = torch.Timer()

-- Draws a fresh random batch of inputs and targets converted to `dtype`.
-- Called outside the timed region, so generation and any host-to-device
-- conversion are not included in the measurements.
local function random_batch(dtype)
  local X = torch.randn(opt.batch_size, opt.input_dim):type(dtype)
  local y = torch.randn(opt.batch_size, opt.output_dim):type(dtype)
  return X, y
end

-- Runs one forward + backward pass through the model and the criterion.
local function run_pass(X, y)
  local y_pred = model:forward(X)
  crit:forward(y_pred, y)
  local dy_pred = crit:backward(y_pred, y)
  model:backward(X, dy_pred)
end

for _, dtype in ipairs(dtypes) do
  print(string.format('Testing dtype %s', dtype))
  model:type(dtype)
  crit:type(dtype)

  -- Untimed warm-up pass: the first pass after a type conversion pays
  -- one-time costs (buffer allocation and, on the GPU, library
  -- initialization) that would otherwise inflate the first trial and skew
  -- the mean.
  run_pass(random_batch(dtype))

  local times = torch.DoubleTensor(opt.num_trials)
  for t = 1, opt.num_trials do
    local X, y = random_batch(dtype)

    -- Drain any pending GPU work before starting the clock, and again
    -- before stopping it, so the asynchronous kernels are fully counted.
    cutorch.synchronize()
    timer:reset()
    run_pass(X, y)
    cutorch.synchronize()

    local elapsed = timer:time().real
    times[t] = elapsed
    if not opt.quiet then
      print(elapsed)
    end
  end

  local mean_time = times:mean()
  table.insert(mean_times, mean_time)
  print(string.format('Mean time: %f', mean_time))
end

-- mean_times[1] is the CPU mean, mean_times[2] the GPU mean (order of dtypes).
print(string.format('GPU speedup: %f', mean_times[1] / mean_times[2]))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment