-- Object mask-proposal (segmentation) script for use with FAIR DeepMask+SharpMask
-- (originally shared as a GitHub gist titled "Image classifier").
--[[----------------------------------------------------------------------------
Copyright (c) 2016-present, Facebook, Inc. All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree. An additional grant
of patent rights can be found in the PATENTS file in the same directory.
Run full-scene inference on a sample image.
Copyright (c) 2016 Surround.io. Modifications: resize images to a maximum
width, and write the output image to a designated directory, named with a
UTC timestamp (one-second resolution).
------------------------------------------------------------------------------]]
require 'torch'
require 'cutorch'
require 'image'
require 'os'
--------------------------------------------------------------------------------
-- command-line interface: one positional argument (the model directory that
-- holds model.t7) followed by named options; parsed values land in `config`
local parser = torch.CmdLine()
parser:text()
parser:text('evaluate deepmask/sharpmask')
parser:text()
parser:argument('-model', 'path to model to load')
parser:text('Options:')
-- each entry is { flag, default, help }, registered in this order
local optionSpecs = {
  { '-img', 'data/testImage.jpg', 'path/to/test/image' },
  { '-gpu', 1,     'gpu device' },
  { '-np',  5,     'number of proposals to save in test' },
  { '-si',  -2.5,  'initial scale' },
  { '-sf',  .5,    'final scale' },
  { '-ss',  .5,    'scale step' },
  { '-dm',  false, 'use DeepMask version of SharpMask' },
  { '-wl',  756,   'image width limit' },
  { '-out', '.',   'output dir' },
}
for _, spec in ipairs(optionSpecs) do
  parser:option(spec[1], spec[2], spec[3])
end
local config = parser:parse(arg)
--------------------------------------------------------------------------------
-- various initializations
-- normalization statistics handed to the inference module below
-- (three entries each, presumably one per RGB channel)
local meanstd = {
  mean = { 0.485, 0.456, 0.406 },
  std  = { 0.229, 0.224, 0.225 },
}
-- float tensors by default; run on the GPU selected with -gpu
torch.setdefaulttensortype('torch.FloatTensor')
cutorch.setDevice(config.gpu)
-- COCO API, used only for its mask-drawing helpers
local coco = require 'coco'
local maskApi = coco.MaskApi
--------------------------------------------------------------------------------
-- load model
-- these scripts presumably register nn.DeepMask / nn.SharpMask (the types
-- checked further down) so torch.load can deserialize the checkpoint
paths.dofile('/root/deepmask/DeepMask.lua')
paths.dofile('/root/deepmask/SharpMask.lua')
print('| loading model file... ' .. config.model)
-- -model is a directory that contains model.t7
local m = torch.load(paths.concat(config.model, 'model.t7'))
-- fail with a clear message instead of a nil-index error on model:inference
assert(m and m.model, 'checkpoint has no "model" field: ' .. config.model)
local model = m.model
-- configure the model to keep config.np proposals, then move it to the GPU
model:inference(config.np)
model:cuda()
--------------------------------------------------------------------------------
-- create inference module
-- scales are powers of two: 2^si, 2^(si+ss), ... up to at most 2^sf.
-- A zero step would make the numeric for-loop spin forever on Lua 5.1/LuaJIT.
assert(config.ss ~= 0, 'scale step (-ss) must be non-zero')
local scales = {}
for i = config.si, config.sf, config.ss do
  scales[#scales + 1] = 2^i
end
assert(#scales > 0, 'no scales generated: check -si, -sf and -ss')
-- load the inference script matching the model; it presumably defines the
-- global `Infer` used below. Any other model type would leave `Infer`
-- undefined, so report it explicitly.
local modelType = torch.type(model)
if modelType == 'nn.DeepMask' then
  paths.dofile('/root/deepmask/InferDeepMask.lua')
elseif modelType == 'nn.SharpMask' then
  paths.dofile('/root/deepmask/InferSharpMask.lua')
else
  error('unsupported model type: ' .. modelType)
end
local infer = Infer{
  np = config.np,
  scales = scales,
  meanstd = meanstd,
  model = model,
  dm = config.dm,
}
--------------------------------------------------------------------------------
-- do it: run inference on one image and save it with the masks drawn on top
print('| start')
-- load image; force 3 channels so grayscale or RGBA files line up with the
-- 3-entry mean/std used by the inference module
local img = image.load(config.img, 3)
local h,w = img:size(2),img:size(3)
-- keep it small: downscale (preserving aspect ratio) when wider than -wl
if w > config.wl then
  local newW = config.wl
  -- image.scale takes integer sizes; never let a very wide image's height
  -- collapse to 0
  local newH = math.max(1, math.floor(newW / w * h))
  img = image.scale(img, newW, newH)
  h,w = img:size(2),img:size(3)
end
-- forward all scales
infer:forward(img)
-- get top proposals (only the masks are needed here)
local masks = infer:getTopProps(.2,h,w)
-- save result: draw the masks over a copy of the input image
local res = img:clone()
maskApi.drawMasks(res, masks, 10)
-- join dir and file name with a separator: plain concatenation produced
-- '.DD-Mon-...' (a hidden file) for the default '-out .'.
-- UTC timestamp with one-second resolution.
local outputFile = paths.concat(config.out, os.date('!%d-%b-%Y-%H:%M:%S') .. '.jpg')
image.save(outputFile, res)
print('| done')
collectgarbage()
-- end of script