Skip to content

Instantly share code, notes, and snippets.

@jcjohnson
Created October 11, 2015 23:52
Show Gist options
  • Save jcjohnson/04e649e285dbf07690db to your computer and use it in GitHub Desktop.
Save jcjohnson/04e649e285dbf07690db to your computer and use it in GitHub Desktop.
require 'torch'
require 'cutorch'
require 'nn'
require 'cunn'
require 'cudnn'
require 'loadcaffe'
-- Command-line options:
--   -model       key into MODELS below
--   -backend     backend name handed to loadcaffe.load
--   -gpu         0-based GPU index (converted to cutorch's 1-based index later)
--   -batch_size  number of images per forward/backward pass
local cmd = torch.CmdLine()
cmd:option('-model', 'alexnet')
cmd:option('-backend', 'nn')
cmd:option('-gpu', 0)
cmd:option('-batch_size', 10)
local params = cmd:parse(arg)

-- Supported models: Caffe deploy prototxt, pretrained weights, and the
-- square spatial size of the random input fed to the network.
local MODELS = {
  ['alexnet'] = {
    proto = 'models/bvlc_alexnet/deploy.prototxt',
    weights = 'models/bvlc_alexnet/bvlc_alexnet.caffemodel',
    size = 227,
  },
  ['caffenet'] = {
    proto = 'models/bvlc_reference_caffenet/deploy.prototxt',
    weights = 'models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel',
    size = 227,
  },
  ['vgg-16'] = {
    proto = 'models/vgg-16/VGG_ILSVRC_16_layers_deploy.prototxt',
    weights = 'models/vgg-16/VGG_ILSVRC_16_layers.caffemodel',
    size = 224,
  },
  ['vgg-19'] = {
    proto = 'models/vgg-19/VGG_ILSVRC_19_layers_deploy.prototxt',
    weights = 'models/vgg-19/VGG_ILSVRC_19_layers.caffemodel',
    size = 224,
  },
}

local spec = MODELS[params.model]
if spec == nil then
  error(string.format('Unrecognized model "%s"', params.model))
end
local model_file, proto_file, size = spec.weights, spec.proto, spec.size
-- Load the selected Caffe model onto the chosen GPU and time repeated
-- forward/backward passes on a random input batch, printing per-iteration
-- and mean timings in milliseconds.

-- cutorch device indices are 1-based; -gpu is 0-based.
cutorch.setDevice(params.gpu + 1)
local cnn = loadcaffe.load(proto_file, model_file, params.backend):cuda()
local data = torch.randn(params.batch_size, 3, size, size):cuda()
-- Upstream gradient for backward; allocated once the output shape is known.
local dout = nil
local timer = torch.Timer()

--- Run fn and return its wall-clock duration in milliseconds, plus fn's result.
-- GPU work is queued asynchronously, so the device is synchronized before
-- starting the timer and again before reading it; otherwise only the
-- launch overhead would be measured.
local function time_ms(fn)
  cutorch.synchronize()
  timer:reset()
  local result = fn()
  cutorch.synchronize()
  return timer:time().real * 1000, result
end

local function forward()
  return cnn:forward(data)
end

local function backward()
  -- dout is an upvalue, so this sees the tensor assigned in run_iteration.
  return cnn:backward(data, dout)
end

--- Time one forward and one backward pass; returns (forward_ms, backward_ms).
local function run_iteration()
  local forward_time, out = time_ms(forward)
  if not dout then
    -- Allocated outside both timed regions, so it does not skew either time.
    dout = torch.randn(#out):cuda()
  end
  local backward_time = time_ms(backward)
  return forward_time, backward_time
end

-- Warm-up pass, not recorded: the first pass typically pays one-time costs
-- (e.g. lazy allocation of output/gradInput buffers in the modules) that
-- would otherwise inflate the reported means.
local num_warmup = 1
for _ = 1, num_warmup do
  run_iteration()
end

local num_iterations = 50
local forward_times = torch.Tensor(num_iterations)
local backward_times = torch.Tensor(num_iterations)
local msg = 'Iteration %d / %d, forward %s ms, backward %s ms'
for i = 1, num_iterations do
  local forward_time, backward_time = run_iteration()
  print(string.format(msg, i, num_iterations, forward_time, backward_time))
  forward_times[i] = forward_time
  backward_times[i] = backward_time
end
print(string.format('Mean forward time: %f', forward_times:mean()))
print(string.format('Mean backward time: %f', backward_times:mean()))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment