@guy4261
Created May 18, 2015 14:54
An iterator that walks through every pixel in the Torch MNIST dataset
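require 'torch'  -- already loaded by the `th` interpreter; needed under plain LuaJIT

-- Load the ASCII-serialized training set and cast its pixel data to the
-- default tensor type (torch.DoubleTensor unless it has been changed).
train = torch.load('mnist.t7/train_32x32.t7', 'ascii')
train.data = train.data:type(torch.getdefaulttensortype())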
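
-- Returns a stateful iterator (a closure) that yields every element of the
-- 4-D tensor dataset.data, one per call, with the last dimension varying fastest.
function pixels(dataset)
  local data = dataset.data
  local dimensions = (#data):totable()  -- size of each of the four dimensions
  -- Current position; the closure keeps these values between calls.
  local d1, d2, d3, d4 = 1, 1, 1, 1
  return function()
    while d1 <= dimensions[1] do
      while d2 <= dimensions[2] do
        while d3 <= dimensions[3] do
          if d4 <= dimensions[4] then
            -- local, so the iterator does not leak a global named `cur`
            local cur = data[d1][d2][d3][d4]
            d4 = d4 + 1
            return cur
          end
          -- Innermost dimension exhausted: advance the next one and reset.
          d3 = d3 + 1
          d4 = 1
        end
        d2 = d2 + 1
        d3 = 1
      end
      d1 = d1 + 1
      d2 = 1
    end
    -- Falling off the end returns nil, which terminates the generic for loop.
  end
end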
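
-- Sum every pixel through the iterator, then print Torch's built-in
-- reduction for comparison; the two values should agree.
sum = 0
for pixel in pixels(train) do
  sum = sum + pixel
end
print(sum)
print(train.data:sum())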
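
The iterator above hard-codes four nested loops, so it only handles 4-D data. As a rough sketch of one alternative (not part of the original gist), the same traversal can be written with a coroutine that recurses to any depth. To keep the example runnable without Torch or the MNIST file, it walks nested plain Lua tables rather than a tensor; the names `walk`, `leaves`, and `batch` are hypothetical and chosen only for illustration.

```lua
-- Recursively walk a nested table, yielding every leaf (non-table) value
-- in order, with the innermost index varying fastest.
local function walk(node)
  if type(node) == 'table' then
    for i = 1, #node do
      walk(node[i])
    end
  else
    coroutine.yield(node)
  end
end

-- Wrap the traversal in a coroutine so it can drive a generic for loop.
-- When walk() finishes, the wrapped function returns nothing, which the
-- for loop sees as nil and stops.
local function leaves(nested)
  return coroutine.wrap(function() walk(nested) end)
end

-- Hypothetical 2x2x3 stand-in for a tiny batch of images.
local batch = {
  { {1, 2, 3}, {4, 5, 6} },
  { {7, 8, 9}, {10, 11, 12} },
}

local total = 0
for value in leaves(batch) do
  total = total + value
end
print(total)  -- 78
```

The coroutine keeps the loop state on its own stack, so no explicit `d1` to `d4` counters are needed and the number of dimensions is not fixed in advance. For a real tensor, element-by-element access from Lua is likely to be far slower than a built-in reduction such as `:sum()`, so an iterator like this is mainly useful for inspection or teaching.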