Skip to content

Instantly share code, notes, and snippets.

@pengsun
Last active April 30, 2016 02:50
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save pengsun/46ee5f67ec611252c0a58450f57768d5 to your computer and use it in GitHub Desktop.
Save pengsun/46ee5f67ec611252c0a58450f57768d5 to your computer and use it in GitHub Desktop.
require'cunn'
require'cudnn'
-- Benchmark configuration. These are globals on purpose: V and C size both
-- lookup tables, padVocabInd is passed to setPadding, and nloop is the
-- number of timed passes read by timing_module.
-- (An earlier, smaller run used M = 80 with the other values unchanged.)
V, C, M, B = 30000 + 1, 500, 500, 100 -- vocabulary size, embedding dim, seq length, batch size (rows of input)
padVocabInd = 1
nloop = 3

-- Shared inputs: a (B x M) block of random token ids produced by :random(V)
-- (torch7: uniform integers in 1..V), moved to the GPU, plus one weight
-- matrix that both lookup implementations copy so they start identical.
weight = nil
input = torch.LongTensor(B, M):random(V):cuda()
weight = torch.CudaTensor(V, C):normal()
--- Benchmark the forward pass of a module.
-- Runs one untimed warm-up pass, waits for the device to go idle, then
-- times `n` further passes and prints the mean time per pass.
-- @param input value passed to `m:forward`
-- @param m     module to benchmark (anything with a `forward` method)
-- @param n     number of timed passes; optional, defaults to the global `nloop`
-- @return mean seconds per forward pass
function timing_module(input, m, n)
   n = n or nloop
   if type(n) ~= 'number' or n <= 0 then
      error('timing_module: loop count must be a positive number', 2)
   end
   -- fprop
   m:forward(input) -- warm up (untimed)
   -- GPU work is queued asynchronously: wait for the warm-up to finish
   -- before starting the clock, otherwise it is billed to the timed loop.
   cutorch.synchronize()
   local start = torch.tic()
   for _ = 1, n do
      m:forward(input)
   end
   -- Same reason: make sure every timed pass has completed before toc.
   cutorch.synchronize()
   local avg = torch.toc(start) / n
   print(torch.type(m) .. ' fprop time ' .. avg)
   return avg
end
-- Baseline: the stock nn.LookupTable on the GPU, initialised from the shared
-- `weight` matrix and marked with the shared padding index.
do
   local baseline = nn.LookupTable(V, C):cuda()
   m1 = baseline
   baseline.weight:copy(weight)
   baseline:setPadding(padVocabInd)
end
print('lookuptable')
timing_module(input, m1)
-- Keep one output around for the correctness comparison at the end.
output1 = m1:forward(input)
-- Candidate: ohnn.LookupTableNew on the GPU, initialised from the same shared
-- `weight` matrix and padding index as the nn.LookupTable baseline.
-- ohnn is not loaded by the requires at the top of the file, so load it here
-- unless the global was already provided by the environment.
-- NOTE(review): assumes the ohnn package registers a global `ohnn` like
-- other torch packages do — confirm against the package.
if ohnn == nil then require'ohnn' end
m2 = ohnn.LookupTableNew(V,C):cuda()
m2.weight:copy(weight)
m2:setPadding(padVocabInd)
print('ohnn.lookuptableNew')
timing_module(input, m2)
-- Output compared against the baseline's output at the end of the script.
output2 = m2:forward(input)
-- verify diff
--- Maximum absolute element-wise difference between two tensors.
-- Both tensors are flattened first, so only their element counts must
-- match, not their shapes.
-- @param a first tensor
-- @param b second tensor with the same number of elements as `a`
-- @return largest |a[i] - b[i]| over all elements
-- @raise error if the element counts differ
function calc_diff(a, b)
   if a:nElement() ~= b:nElement() then
      error(('calc_diff: element counts differ (%d vs %d)')
         :format(a:nElement(), b:nElement()), 2)
   end
   -- view(-1) requires contiguous storage; contiguous() makes strided
   -- inputs (e.g. transposed or sliced tensors) safe to flatten.
   local delta = a:contiguous():view(-1) - b:contiguous():view(-1)
   return delta:abs():max()
end
-- Largest element-wise gap between the two implementations' outputs.
-- %g instead of %f so a small but non-zero gap (e.g. 1e-7) is shown as such
-- rather than being rounded to 0.000000.
d = calc_diff(output1, output2)
print( ('diff = %g'):format(d) )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment