-- Gist culurciello/1c1a85b2791bf574d49b, created March 5, 2015 15:47.
-- Save it to your computer and use it in GitHub Desktop.
-- Description: precision test for Torch7 and new hardware.
-- Note (GitHub): this file contains bidirectional Unicode text that may be
-- interpreted or compiled differently than what appears below. To review it,
-- open the file in an editor that reveals hidden Unicode characters.
-- Learn more about bidirectional Unicode characters on GitHub.
--[[ precision-test
Compare precision of hardware and software implementation
run with: qlua precision-network.lua
--]]
require 'nn'
require 'pl'
require 'image'
torch.setdefaulttensortype('torch.FloatTensor')

-- Problem dimensions shared by the network, the input image and the report.
local iC = 1    -- input channels (input planes)
local iH = 128  -- input height
local iW = iH   -- input width (square input)
local kC = 4    -- number of convolution kernels (output planes)
local kH = 3    -- kernel height
local kW = kH   -- kernel width (square kernel)
local pH = 2    -- pooling height (only used by the commented-out pooling layer)
local pW = pH   -- pooling width
-- define network and weights
-- A single convolution layer; the commented-out modules can be enabled to
-- also exercise a threshold non-linearity and max pooling.
-- Declared local so the script does not leak a global named `network`.
local network = nn.Sequential()
-- nn.SpatialConvolutionMM takes (nInputPlane, nOutputPlane, kW, kH):
-- width before height.
network:add(nn.SpatialConvolutionMM(iC, kC, kW, kH))
-- network:add(nn.Threshold())
-- network:add(nn.SpatialMaxPooling(pW, pH, pW, pH))

-- Deterministic parameters so software and hardware see identical weights:
-- every weight of kernel i is 0.01*i and every bias is 0.
local conv = network.modules[1]
for i = 1, kC do
  conv.weight[i]:fill(0.01 * i)
end
conv.bias:fill(0)
-- use lena as an input
-- Red channel of the Lena test image, rescaled once to iH x iW and shaped
-- as a single plane. Values are presumably floats in [0,1] -- TODO confirm
-- for the installed `image` package.
local lena_red = image.lena()[1]
local lena_scaled = image.scale(lena_red, iW, iH):resize(1, iH, iW)

-- Software input: the float image repeated (as a view) across iC planes.
local lena_sw = lena_scaled:expand(iC, iH, iW)

-- Hardware input: an 8-bit copy. Clone first because mul() works in place.
-- Scale by 255, not 256: a pixel of 1.0 * 256 does not fit in a byte.
local lena_byte = lena_scaled:clone():mul(255):byte()
local lena_hw = torch.repeatTensor(lena_byte, iC, 1, 1)
-- parse network
-- nn modules reuse their output buffer: forward() hands back the same tensor
-- object on every call. Clone each result so the software and hardware
-- outputs stay independent once the second call is swapped for a different
-- implementation or input.
local dst_sw = network:forward(lena_sw):clone()
local dst_hw = network:forward(lena_sw):clone() -- replace this with new library functions!
-- print output
print('==> Precision test')
-- number of decimal places kept when printing and comparing outputs
local precision = 5
-- row and column of each output plane that is sampled for the comparison
local coordinate = 20

-- Round x to `precision` decimal places; halfway cases go up (toward +inf).
-- Despite its name this rounds rather than truncates; the name is kept
-- because the reporting loop calls it.
-- Uses the ^ operator because math.pow is deprecated since Lua 5.3.
local function trunc(x)
  local scale = 10 ^ precision
  return math.floor(x * scale + 0.5) / scale
end
-- Compare the software and hardware outputs at (coordinate, coordinate)
-- of every output plane and print them side by side.
for i = 1, kC do
  local raw_sw = dst_sw[i][coordinate][coordinate]
  local raw_hw = dst_hw[i][coordinate][coordinate]
  local sw = trunc(raw_sw)
  local hw = trunc(raw_hw)
  -- Difference of the unrounded values, so rounding each side first can
  -- neither inflate nor hide a discrepancy.
  local diff = trunc(math.abs(raw_sw - raw_hw))
  print('output['..i..']: ', 'CPU = ', sw, 'FPGA = ', hw, 'DIFF = ', diff)
end
-- display output
-- NOTE(review): image.display presumably needs a Qt-enabled interpreter;
-- the file header says to run this script with qlua.
image.display(dst_hw)
-- GitHub: sign up for free to join this conversation on GitHub.
-- Already have an account? Sign in to comment.