Skip to content

Instantly share code, notes, and snippets.

@ProGamerGov
Last active April 7, 2017 20:12
Show Gist options
  • Save ProGamerGov/18b90683816e8e82795597afa1d6c96c to your computer and use it in GitHub Desktop.
Save ProGamerGov/18b90683816e8e82795597afa1d6c96c to your computer and use it in GitHub Desktop.
-- Experimental mask feature
-- Command-line options: paths of the content and style mask images.
cmd:option('-content_mask', '', 'The content mask')
cmd:option('-style_mask', '', 'The style mask')

-- Load each mask as a 3-channel image and resize it (bilinear) to
-- params.image_size.
local content_mask = image.scale(image.load(params.content_mask, 3), params.image_size, 'bilinear')
local style_mask = image.scale(image.load(params.style_mask, 3), params.image_size, 'bilinear')

-- Colours that can appear in a mask image; the region tables below are
-- indexed in this same order.
local mask_regions = {'blue', 'green', 'black', 'white', 'red', 'yellow', 'grey', 'lightblue', 'purple'}

-- For every colour, extract the matching region from the content mask and
-- then from the style mask.
local content_mask_regions, style_mask_regions = {}, {}
for idx, colour in ipairs(mask_regions) do
  content_mask_regions[idx] = ExtractMask(content_mask, colour)
  style_mask_regions[idx] = ExtractMask(style_mask, colour)
end
---------------------------------------------------------------------------
-- Keep the per-colour mask regions spatially aligned with the feature maps of
-- the current layer. layer_type and is_pooling come from the enclosing code
-- (not visible here).
local is_conv = (layer_type == 'nn.SpatialConvolution' or layer_type == 'cudnn.SpatialConvolution')
if is_pooling then
  -- Halve a mask's height and width, rounding up.
  -- NOTE(review): assumes every pooling layer downsamples by exactly 2 with
  -- ceil rounding (VGG-style); confirm against the pooling layers in use.
  local function halve(mask)
    local height, width = mask:size(1), mask:size(2)
    return image.scale(mask, math.ceil(width / 2), math.ceil(height / 2))
  end
  for k = 1, #mask_regions do
    content_mask_regions[k] = halve(content_mask_regions[k])
    style_mask_regions[k] = halve(style_mask_regions[k])
  end
elseif is_conv then
  -- 3x3 average, stride 1, padding 1: same spatial size, softened region
  -- edges (presumably to mirror a 3x3 convolution's receptive field).
  local sap = nn.SpatialAveragePooling(3, 3, 1, 1, 1, 1):float()
  local function smooth(mask)
    -- Add a leading plane dimension for the pooling module, drop it again,
    -- and clone because forward() hands back the module's own output
    -- tensor, which the next forward() call reuses.
    return sap:forward(mask:repeatTensor(1, 1, 1))[1]:clone()
  end
  for k = 1, #mask_regions do
    content_mask_regions[k] = smooth(content_mask_regions[k])
    -- Previously written to a misspelled global (style_mask_regionss), which
    -- raised an error and left the style masks unsmoothed.
    style_mask_regions[k] = smooth(style_mask_regions[k])
  end
end
-- Rebind to a fresh table. deepcopy only copies Lua tables, so the new table
-- holds the same tensor objects; later element assignments no longer affect
-- any earlier holder of the old table.
-- NOTE(review): style_mask_regions is not rebound the same way; confirm
-- whether that asymmetry is intended.
content_mask_regions = deepcopy(content_mask_regions)
---------------------------------------------------------------------------
--- Recursively copy a value.
-- Tables are copied key by key (keys and values are themselves deep-copied),
-- and the copy receives a deep copy of the original's metatable. Any other
-- type (number, string, boolean, function, userdata such as torch tensors,
-- ...) is returned unchanged, so userdata stored in a table is shared rather
-- than cloned.
-- Each source table is copied exactly once per top-level call. Cyclic
-- structures therefore terminate, and repeated references to the same table
-- map to the same copy.
-- @param orig value to copy
-- @param copies optional memo table mapping already-copied source tables to
--   their copies; used internally for recursion, callers normally omit it
-- @return the copy (or orig itself for non-table values)
function deepcopy(orig, copies)
  if type(orig) ~= 'table' then -- number, string, boolean, userdata, etc
    return orig
  end
  copies = copies or {}
  local existing = copies[orig]
  if existing then
    return existing
  end
  local copy = {}
  -- Register before recursing so self-references resolve to this copy.
  copies[orig] = copy
  -- next/nil iteration visits raw keys without invoking a __pairs metamethod.
  for orig_key, orig_value in next, orig, nil do
    copy[deepcopy(orig_key, copies)] = deepcopy(orig_value, copies)
  end
  setmetatable(copy, deepcopy(getmetatable(orig), copies))
  return copy
end
---------------------------------------------------------------------------
-- Extract the various mask regions based on color.
--
-- Each supported colour is described by a target level per channel (R, G, B):
--   0   -> channel < thresh
--   1   -> channel > 1 - thresh
--   0.5 -> 0.5 - thresh < channel < 0.5 + thresh
local MASK_COLOR_LEVELS = {
  green     = {0,   1,   0},
  black     = {0,   0,   0},
  white     = {1,   1,   1},
  red       = {1,   0,   0},
  blue      = {0,   0,   1},
  yellow    = {1,   1,   0},
  grey      = {0.5, 0.5, 0.5},
  lightblue = {0,   1,   1},
  purple    = {1,   0,   1},
}

-- 0/1 mask of the elements of `channel` lying within `thresh` of `level`.
local function match_level(channel, level, thresh)
  if level == 0 then
    return torch.lt(channel, thresh)
  elseif level == 1 then
    return torch.gt(channel, 1 - thresh)
  end
  return torch.cmul(torch.gt(channel, level - thresh), torch.lt(channel, level + thresh))
end

--- Build a mask of the pixels of a segmentation image that have a given colour.
-- @param seg image whose planes seg[1], seg[2], seg[3] are the R, G, B
--   channels; presumably scaled to [0, 1] as produced by image.load (TODO confirm)
-- @param color one of 'green', 'black', 'white', 'red', 'blue', 'yellow',
--   'grey', 'lightblue', 'purple'
-- @param thresh optional per-channel tolerance; defaults to 0.1, the value
--   previously hard-coded
-- @return float tensor holding 1 where all three channels match, 0 elsewhere
-- @raise 'Mask color not recognized' when color is not one of the above
function ExtractMask(seg, color, thresh)
  local levels = MASK_COLOR_LEVELS[color]
  if not levels then
    error('Mask color not recognized')
  end
  thresh = thresh or 0.1
  local mask = match_level(seg[1], levels[1], thresh)
  mask:cmul(match_level(seg[2], levels[2], thresh))
  mask:cmul(match_level(seg[3], levels[3], thresh))
  return mask:float()
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment