Skip to content

Instantly share code, notes, and snippets.

@soumith
Created October 11, 2014 20:52
Show Gist options
  • Save soumith/1f7645f14738d39be2b5 to your computer and use it in GitHub Desktop.
Save soumith/1f7645f14738d39be2b5 to your computer and use it in GitHub Desktop.
CuDNN SpatialMaxPooling bug
require 'cudnn'
require 'cunn'
-- Test suite table: every function stored here is registered with the
-- torch.Tester instance created at the bottom of the script.
local cudnntest = {}

-- Maximum absolute differences tolerated between the cudnn module and the
-- nn reference, for forward outputs and backward gradInputs respectively.
local precision_forward, precision_backward = 1e-4, 1e-2

-- Declared but not referenced anywhere else in this script.
local precision_jac = 1e-3
local nloop = 1
local times = {}
--- Compare cudnn.SpatialMaxPooling against the reference nn.SpatialMaxPooling.
-- Both modules run on the same CUDA input with a non-square kernel (4x3) and
-- stride (4x3). The forward outputs must agree within `precision_forward` and
-- the gradInputs within `precision_backward`. Results are reported through
-- the global `mytester`.
function cudnntest.SpatialMaxPooling()
   local bs = 10                -- batch size
   local from = 34              -- number of feature planes
   local ki, kj = 4, 3          -- kernel size along the i (last) / j dims
   local si, sj = 4, 3          -- stride along the i (last) / j dims
   local outi, outj = 62, 90    -- desired output size along i / j
   -- Input extent chosen so that unpadded pooling yields exactly outi x outj.
   local ini = (outi - 1) * si + ki
   local inj = (outj - 1) * sj + kj

   local input = torch.randn(bs, from, inj, ini):cuda()
   local gradOutput = torch.randn(bs, from, outj, outi):cuda()

   -- Reference results from the nn/cunn implementation.
   local sconv = nn.SpatialMaxPooling(ki, kj, si, sj):cuda()
   local groundtruth = sconv:forward(input)
   local groundgrad = sconv:backward(input, gradOutput)
   cutorch.synchronize()

   -- Results under test from the cudnn implementation. The module is run
   -- forward twice before backward (kept as-is, since this is a bug repro);
   -- no serialization round-trip is performed between the two calls.
   local gconv = cudnn.SpatialMaxPooling(ki, kj, si, sj):cuda()
   gconv:forward(input)
   local rescuda = gconv:forward(input)
   local resgrad = gconv:backward(input, gradOutput)
   cutorch.synchronize()

   -- Named `err` so as not to shadow Lua's builtin error().
   local err = rescuda:float() - groundtruth:float()
   mytester:assertlt(err:abs():max(), precision_forward, 'error on state (forward) ')
   err = resgrad:float() - groundgrad:float()
   mytester:assertlt(err:abs():max(), precision_backward, 'error on state (backward) ')
end
torch.setdefaulttensortype('torch.FloatTensor')
math.randomseed(os.time())
-- Intentionally global: cudnntest.SpatialMaxPooling refers to `mytester`
-- as a free (global) name, so it must not be made local here.
mytester = torch.Tester()
mytester:add(cudnntest)
-- Fixed seed so torch.randn produces the same inputs on every run.
torch.manualSeed(10)
-- Run every test registered in `cudnntest`.
mytester:run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment