Skip to content

Instantly share code, notes, and snippets.

@karandwivedi42
Last active August 5, 2016 12:09
Show Gist options
  • Save karandwivedi42/4d217ff054daf09c93b83096093ac8a1 to your computer and use it in GitHub Desktop.
Save karandwivedi42/4d217ff054daf09c93b83096093ac8a1 to your computer and use it in GitHub Desktop.
require 'image'
-- Image Transformations
--
-- Each constructor below takes its configuration and returns a closure
-- `function(img) -> img` that is fed to `tnt.transform.compose`. Images are
-- indexed channel-first: size(2) is the height and size(3) is the width.
local M = {}

--- Shift and scale the first three channels: (x + offset) / divisor.
-- The input is cloned first, so the caller's tensor is never modified.
-- @tparam[opt=0.5] number offset value added to every element
-- @tparam[opt=10] number divisor value each shifted element is divided by
-- @treturn function transform returning a new, normalized tensor
M.ColorNormalize = function(offset, divisor)
  offset = offset or 0.5
  divisor = divisor or 10
  return function(img)
    img = img:clone()
    for channel = 1, 3 do
      -- img[channel] is a view into the clone, so this updates it in place
      img[channel]:add(offset):div(divisor)
    end
    return img
  end
end

--- Resize so the shorter side equals `size`, keeping the aspect ratio.
-- The input is returned as-is when its shorter side already equals `size`.
-- @tparam number size target length of the shorter side, in pixels
-- @treturn function transform taking and returning an image tensor
M.Scale = function(size)
  return function(input)
    local w, h = input:size(3), input:size(2)
    if (w <= h and w == size) or (h <= w and h == size) then
      return input
    end
    if w < h then
      return image.scale(input, size, h / w * size)
    end
    return image.scale(input, w / h * size, size)
  end
end

--- Mirror the image left-to-right with probability `prob`.
-- @tparam number prob flip probability, compared against torch.uniform()
-- @treturn function transform taking and returning an image tensor
M.HorizontalFlip = function(prob)
  return function(input)
    if torch.uniform() < prob then
      return image.hflip(input)
    end
    return input
  end
end

--- Cut a random `size` x `size` window out of the image.
-- The input itself (not a copy) is returned when it is already exactly
-- `size` x `size`. Raises an error if the image is smaller than the crop.
-- @tparam number size side length of the square crop, in pixels
-- @treturn function transform taking and returning an image tensor
M.RandomCrop = function(size)
  return function(input)
    local w, h = input:size(3), input:size(2)
    if w == size and h == size then
      return input
    end
    -- Without this guard the upper bound passed to torch.random below
    -- would be negative.
    assert(w >= size and h >= size,
      string.format('image %sx%s is smaller than crop size %s', w, h, size))
    local x1, y1 = torch.random(0, w - size), torch.random(0, h - size)
    local out = image.crop(input, x1, y1, x1 + size, y1 + size)
    assert(out:size(2) == size and out:size(3) == size, 'wrong crop size')
    return out
  end
end
-------------------------------------
-- Fixture: a random 3x256x256 tensor saved as a JPEG. Every dataset sample
-- re-decodes this file. NOTE(review): randn values fall outside [0, 1];
-- confirm how image.save treats out-of-range pixels.
do
  local noise = torch.randn(3, 256, 256)
  image.save("temp.jpg", noise)
end
local tnt = require("torchnet")
-- Samples per batch; the dataset length in getIterator is also derived from it.
local batchSize = 256
--- Build a 4-thread torchnet iterator over synthetic batches.
-- Each sample re-decodes 'temp.jpg' and runs it through
-- Scale(256) -> RandomCrop(224) -> ColorNormalize() -> HorizontalFlip(0.5).
-- The target is the sample's index. Batches hold `batchSize` samples, and the
-- incomplete final batch is dropped ('skip-last').
-- @return tnt.ParallelDatasetIterator
local function getIterator()
  -- Passed as `init`: loads the libraries and exposes the transforms.
  local function initWorker()
    require 'torchnet'
    require 'image'
    -- Deliberate global: buildDataset looks the transforms up by this name.
    imtransform = M
  end

  -- Passed as `closure`: builds the batched, augmented dataset to iterate.
  local function buildDataset()
    local samples = tnt.ListDataset{
      list = torch.range(1, batchSize * 500):long(),
      load = function(idx)
        local decoded = image.load('temp.jpg')
        return { input = decoded:float(), target = torch.LongTensor{idx} }
      end,
    }
    local augment = tnt.transform.compose{
      imtransform.Scale(256),
      imtransform.RandomCrop(224),
      imtransform.ColorNormalize(),
      imtransform.HorizontalFlip(0.5),
    }
    return samples:transform{ input = augment }:batch(batchSize, 'skip-last')
  end

  return tnt.ParallelDatasetIterator{
    nthread = 4,
    init = initWorker,
    closure = buildDataset,
  }
end
-- Benchmark: drain one full pass of the iterator. After every batch, and once
-- more at the end, print the cumulative `real` time since the timer was reset.
local iter = getIterator()
-- Local rather than an accidental global.
local timer = torch.Timer()
timer:reset()
for _ in iter() do
  print(timer:time().real)
end
print(timer:time().real)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment