Skip to content

Instantly share code, notes, and snippets.

@soumith
Last active August 29, 2015 14:01
Show Gist options
  • Save soumith/2003b3de9d98cdbfa294 to your computer and use it in GitHub Desktop.
Save soumith/2003b3de9d98cdbfa294 to your computer and use it in GitHub Desktop.
require 'torch'
require 'nn'
-- Shared state for the tests below: the torch.Tester instance, the largest
-- absolute gradient error the tests tolerate, and the table of test functions.
local mytester, precision, critest = torch.Tester(), 1e-5, {}
--- Estimate d(loss)/d(input) of a criterion by central differences.
-- Each element of `input` is perturbed by +eps and then -eps, and the
-- criterion is re-evaluated each time. The element is then restored to its
-- exact original value, so `input` is unchanged when this returns.
-- @param cri criterion whose :forward(input, target) returns a number
-- @param input 1-D tensor (indexed as input[i]); modified temporarily
-- @param target passed through to cri:forward unchanged
-- @param eps perturbation step size
-- @return tensor of length input:size(1) holding (f(x+h) - f(x-h)) / 2h
local function central_difference_grad(cri, input, target, eps)
   local n = input:size(1)
   local grad = torch.Tensor(n)
   for i = 1, n do
      local x_i = input[i]
      input[i] = x_i + eps
      local f_plus = cri:forward(input, target)
      input[i] = x_i - eps
      local f_minus = cri:forward(input, target)
      -- Write back the saved value instead of undoing the perturbation
      -- arithmetically, which can leave rounding drift in input[i].
      input[i] = x_i
      grad[i] = (f_plus - f_minus) / (2 * eps)
   end
   return grad
end

--- Check nn.MSECriterion:backward against a central-difference estimate
-- of the gradient of :forward, on a random input/target pair.
function critest.MSECriterion()
   local n = 100
   -- NOTE(review): eps = 1e-6 assumes the default tensor type is double;
   -- in float precision the finite difference would be dominated by
   -- rounding error and the check would fail. Confirm before changing types.
   local eps = 1e-6
   local input = torch.rand(n)
   -- target differs from input by uniform random noise, one draw per element
   local target = input:clone():add(torch.rand(n))
   local cri = nn.MSECriterion()
   cri:forward(input, target)
   -- backward() returns the criterion's gradInput buffer; clone it so the
   -- forward() calls made during finite differencing cannot touch our copy.
   local dfdx = cri:backward(input, target):clone()
   local centraldiff_dfdx = central_difference_grad(cri, input, target, eps)
   local err = (centraldiff_dfdx - dfdx):abs():max()
   mytester:assertlt(err, precision,
      'MSECriterion: backward() disagrees with central-difference gradient')
end
-- Register every test function stored in critest with the tester, then
-- execute all registered tests.
do
   local tester = mytester
   tester:add(critest)
   tester:run()
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment