Created
August 29, 2016 15:12
-
-
Save jhludwig/9d17d8c7abd2822f863ca8256ec3a82f to your computer and use it in GitHub Desktop.
Image classifier for use with FAIR Deepmask+Sharpmask
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
--[[---------------------------------------------------------------------------- | |
Copyright (c) 2016-present, Facebook, Inc. All rights reserved. | |
This source code is licensed under the BSD-style license found in the | |
LICENSE file in the root directory of this source tree. An additional grant | |
of patent rights can be found in the PATENTS file in the same directory. | |
Run full-scene inference on a sample image
Copyright (c) 2016 Surround.io. Modifications: resize images to a maximum
width, and write output images to a designated directory under a unique
timestamped name.
------------------------------------------------------------------------------]] | |
require 'torch' | |
require 'cutorch' | |
require 'image' | |
require 'os' | |
-------------------------------------------------------------------------------- | |
-- parse arguments
-- Positional argument: the model directory (the loader below reads
-- <model>/model.t7 from it). Options are declared from an ordered table so the
-- --help listing keeps exactly this order.
local cmd = torch.CmdLine()
cmd:text()
cmd:text('evaluate deepmask/sharpmask')
cmd:text()
cmd:argument('-model', 'path to model to load')
cmd:text('Options:')

local optionSpecs = {
  -- { flag, default value, help text }
  { '-img', 'data/testImage.jpg', 'path/to/test/image' },
  { '-gpu', 1, 'gpu device' },
  { '-np', 5, 'number of proposals to save in test' },
  { '-si', -2.5, 'initial scale' },
  { '-sf', .5, 'final scale' },
  { '-ss', .5, 'scale step' },
  { '-dm', false, 'use DeepMask version of SharpMask' },
  { '-wl', 756, 'image width limit' },
  { '-out', '.', 'output dir' },
}
for _, spec in ipairs(optionSpecs) do
  cmd:option(spec[1], spec[2], spec[3])
end

local config = cmd:parse(arg)
-------------------------------------------------------------------------------- | |
-- various initializations
-- Make single-precision float the default tensor type, then select the GPU
-- chosen with -gpu. Order is kept as-is: both happen before 'coco' is loaded.
torch.setdefaulttensortype('torch.FloatTensor')
cutorch.setDevice(config.gpu)

-- COCO mask helpers; maskApi is used at the end to draw the predicted masks.
local coco = require 'coco'
local maskApi = coco.MaskApi

-- Per-channel normalization statistics passed to the Infer module.
-- NOTE(review): these appear to be the usual ImageNet RGB mean/std — confirm.
local meanstd = {
  mean = { 0.485, 0.456, 0.406 },
  std  = { 0.229, 0.224, 0.225 },
}
-------------------------------------------------------------------------------- | |
-- load model
-- Define the DeepMask and SharpMask classes first; presumably torch.load needs
-- them in order to deserialize the checkpoint's model object.
-- NOTE(review): the /root/deepmask install location is hard-coded.
paths.dofile('/root/deepmask/DeepMask.lua')
paths.dofile('/root/deepmask/SharpMask.lua')

print('| loading model file... ' .. config.model)
local checkpoint = torch.load(config.model .. '/model.t7')
local model = checkpoint.model

-- Put the network in inference mode with the -np proposal count, then move
-- it to the currently selected GPU.
model:inference(config.np)
model:cuda()
-------------------------------------------------------------------------------- | |
-- create inference module
-- Input scales are 2^si, 2^(si+ss), ..., up to 2^sf.
-- A zero step would make this numeric for-loop run forever under LuaJIT /
-- Lua 5.1, so reject it up front. Negative steps (descending ranges) remain
-- allowed as before.
assert(config.ss ~= 0, '-ss (scale step) must be non-zero')
local scales = {}
for i = config.si, config.sf, config.ss do
  scales[#scales + 1] = 2 ^ i
end
assert(#scales > 0, string.format(
  'no scales generated from -si %s to -sf %s with -ss %s',
  tostring(config.si), tostring(config.sf), tostring(config.ss)))

-- Load the Infer implementation matching the model class. The dofile'd script
-- is expected to define the global `Infer` constructor used below.
local modelType = torch.type(model)
if modelType == 'nn.DeepMask' then
  paths.dofile('/root/deepmask/InferDeepMask.lua')
elseif modelType == 'nn.SharpMask' then
  paths.dofile('/root/deepmask/InferSharpMask.lua')
else
  -- Without this branch `Infer` stays nil and the script dies below with an
  -- opaque "attempt to call global 'Infer' (a nil value)".
  error(string.format(
    'unsupported model type %q: expected nn.DeepMask or nn.SharpMask',
    modelType))
end

local infer = Infer{
  np = config.np,
  scales = scales,
  meanstd = meanstd,
  model = model,
  dm = config.dm,
}
-------------------------------------------------------------------------------- | |
-- do it
print('| start')

-- Load the image with exactly 3 channels. Grayscale or RGBA inputs would
-- otherwise not match the 3-entry mean/std normalization.
local img = image.load(config.img, 3)
local h, w = img:size(2), img:size(3)

-- Keep it small: downscale images wider than -wl, preserving aspect ratio.
-- Round the new height to a whole number of pixels rather than passing the
-- fractional value newW / w * h to image.scale.
if w > config.wl then
  local newW = config.wl
  local newH = math.max(1, math.floor(newW / w * h + 0.5))
  img = image.scale(img, newW, newH)
  h, w = img:size(2), img:size(3)
end

-- forward all scales
infer:forward(img)

-- Get the top proposals.
-- NOTE(review): the meaning of the .2 argument is defined by Infer:getTopProps
-- (not visible here) — confirm before changing it.
local masks = infer:getTopProps(.2, h, w)

-- Draw the masks on a copy of the image and save it inside the -out directory
-- under a UTC timestamp name. paths.concat inserts the directory separator;
-- the previous plain `..` produced e.g. ".29-Aug-2016-...jpg" for -out ".".
-- The path is no longer routed through string.format, so '%' characters in
-- -out are safe.
local res = img:clone()
maskApi.drawMasks(res, masks, 10)
local outputFile = paths.concat(config.out, os.date('!%d-%b-%Y-%H:%M:%S') .. '.jpg')
image.save(outputFile, res)

print('| done')
collectgarbage()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment