Skip to content

Instantly share code, notes, and snippets.

@Atcold
Last active August 29, 2015 13:57
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Atcold/9746909 to your computer and use it in GitHub Desktop.
Save Atcold/9746909 to your computer and use it in GitHub Desktop.
Testing comparison operators in Torch with `Double` and `Cuda` *Tensor*
--------------------------------------------------------------------------------
-- This script tests the comparison operators of Torch
--------------------------------------------------------------------------------
-- It looks like there are some problems with CudaTensor...
-- Let's find out :)
--------------------------------------------------------------------------------
require 'cutorch'
--- Print a short summary of a tensor under a "Stats" heading: its min, max,
-- mean (labelled "avg") and standard deviation, one per line.
-- The statistics are queried in the order min, max, mean, std, and the whole
-- report is emitted with a single print call.
-- @param tensor a Torch tensor (this script passes both CPU tensors and their
--   `:cuda()` copies)
function printStat(tensor)
   local rows = {
      { 'min', tensor:min() },
      { 'max', tensor:max() },
      { 'avg', tensor:mean() },
      { 'std', tensor:std() },
   }
   local report = 'Stats\n'
   for _, row in ipairs(rows) do
      report = report .. ' - ' .. row[1] .. ': ' .. row[2] .. '\n'
   end
   print(report)
end
--- Prompt the user and block until a line is read from standard input,
-- then emit a blank line. Used between sections of this script so each
-- section's output can be inspected before the next one runs.
function pause()
   io.write('Press <Return> to continue...')
   -- The prompt has no trailing newline, so it can stay in the output buffer
   -- (e.g. when stdout is not a terminal); flush so it is visible before we
   -- block waiting for input.
   io.flush()
   io.read()
   io.write('\n')
end
-- Section 1: a 3x3 tensor holding every quotient v/u for v, u in {-1, 0, 1}.
-- This mixes finite values (-1, 0, 1, -0) with -inf (-1/0), inf (1/0) and
-- nan (0/0). Values are written one by one into the tensor's storage using
-- the global counter i (left at 10 afterwards).
-- NOTE(review): assumes the freshly allocated tensor is contiguous, so storage
-- order is row-major.
a = torch.Tensor(3, 3)
val = {-1, 0, 1}
i = 1
for row = 1, #val do
   local numerator = val[row]
   for col = 1, #val do
      a:storage()[i] = numerator / val[col]
      i = i + 1
   end
end

-- Show the tensor, then its stats computed on the CPU and on the GPU copy.
print('print(a)')
print(a)
print('printStat(a)')
printStat(a)
print('printStat(a:cuda())')
printStat(a:cuda())
pause()
-- Section 2: for each quotient v/u (v, u in val), build a 3x3 tensor filled
-- with that single value, store it in tensorTable[i], and print it together
-- with its CPU and GPU stats, pausing after each one.
tensorTable = {}
i = 1
for outer = 1, #val do
   for inner = 1, #val do
      local filled = torch.Tensor(3, 3):fill(val[outer] / val[inner])
      tensorTable[i] = filled
      local label = 'tensorTable[' .. i .. ']'
      print('print(' .. label .. ')')
      print(filled)
      print('printStat(' .. label .. ')')
      printStat(filled)
      print('printStat(' .. label .. ':cuda())')
      printStat(filled:cuda())
      i = i + 1
      pause()
   end
end
-- Print a labelled tensor, then its stats on the CPU and on its `:cuda()`
-- copy, echoing each statement before running it (same format as the
-- sections above).
local function describe(name, tensor)
   print('print(' .. name .. ')')
   print(tensor)
   print('printStat(' .. name .. ')')
   printStat(tensor)
   print('printStat(' .. name .. ':cuda())')
   printStat(tensor:cuda())
end

-- Sections 3 and 4: 3x3 tensors filled with nan (0/0) except for a single 0.
-- b and c differ only in where that 0 sits: the centre for b, the first
-- element for c. Presumably this checks whether min/max results depend on the
-- position of the non-nan element; confirm with the author.
b = torch.Tensor(3,3):fill(0/0)
b[2][2] = 0
describe('b', b)

c = torch.Tensor(3,3):fill(0/0)
c[1][1] = 0
-- Previously the 'printStat(c:cuda())' label was printed without the call;
-- describe() now runs it, so c's GPU stats are actually shown.
describe('c', c)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment