-
-
Save ProGamerGov/18b90683816e8e82795597afa1d6c96c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- Experimental mask feature
-- Command-line options: paths to the colour-coded segmentation images for the
-- content and the style input.
cmd:option('-content_mask', '', 'The content mask')
cmd:option('-style_mask', '', 'The style mask')

-- Load each mask with 3 channels, then resize it bilinearly to
-- params.image_size (image.scale treats a numeric size as the longer side).
-- NOTE(review): the style mask is sized with params.image_size as well; if the
-- style image itself is scaled differently, the two may not line up — confirm.
local content_mask = image.scale(image.load(params.content_mask, 3), params.image_size, 'bilinear')
local style_mask = image.scale(image.load(params.style_mask, 3), params.image_size, 'bilinear')

-- Region colours; entry k of each region table is the mask for mask_regions[k].
local mask_regions = {'blue', 'green', 'black', 'white', 'red', 'yellow', 'grey', 'lightblue', 'purple'}
local content_mask_regions, style_mask_regions = {}, {}
for idx, region_color in ipairs(mask_regions) do
  content_mask_regions[idx] = ExtractMask(content_mask, region_color)
  style_mask_regions[idx] = ExtractMask(style_mask, region_color)
end
--------------------------------------------------------------------------- | |
-- Keep every mask region aligned with the spatial size of the feature maps as
-- layers are added. NOTE(review): this fragment expects `layer_type` and
-- `is_pooling` from the enclosing layer-building loop, and `deepcopy` defined
-- elsewhere in the script — confirm at the call site.
local is_conv = (layer_type == 'nn.SpatialConvolution' or layer_type == 'cudnn.SpatialConvolution')
if is_pooling then
  -- Halve both dimensions of a 2D mask, rounding up. image.scale takes
  -- (src, width, height), so size(2) is the width and size(1) the height.
  -- NOTE(review): rounding up assumes the pooling layers use ceil mode —
  -- confirm against the loaded network.
  local function halve(mask)
    return image.scale(mask, math.ceil(mask:size(2) / 2), math.ceil(mask:size(1) / 2))
  end
  for k = 1, #mask_regions do
    content_mask_regions[k] = halve(content_mask_regions[k])
    style_mask_regions[k] = halve(style_mask_regions[k])
  end
elseif is_conv then
  -- A 3x3, stride-1, padding-1 average pool keeps the mask size unchanged
  -- while softening region borders (presumably to mirror a 3x3 convolution's
  -- receptive field — TODO confirm).
  local sap = nn.SpatialAveragePooling(3, 3, 1, 1, 1, 1):float()
  -- The pooling module needs a plane dimension, hence the 1xHxW view; [1]
  -- drops it again. clone() is required because the module reuses its output
  -- buffer on every forward call, which would otherwise alias all masks.
  local function blur(mask)
    return sap:forward(mask:repeatTensor(1, 1, 1))[1]:clone()
  end
  for k = 1, #mask_regions do
    content_mask_regions[k] = blur(content_mask_regions[k])
    style_mask_regions[k] = blur(style_mask_regions[k])
  end
end
-- Rebuild the content region table. NOTE(review): the visible deepcopy returns
-- non-table values unchanged, so the tensors themselves are shared, not cloned;
-- only the content table (not the style table) is copied — confirm intent.
content_mask_regions = deepcopy(content_mask_regions)
--------------------------------------------------------------------------- | |
--- Recursively copy a value.
-- A table is copied key by key (keys and values are themselves deep-copied),
-- and the copy receives a deep copy of the original's metatable. Any other
-- value (number, string, boolean, function, userdata such as a torch tensor)
-- is returned as-is, i.e. shared with the original.
-- Each table reachable from `orig` is copied exactly once. Cyclic structures,
-- including the common `Class.__index = Class` metatable, therefore terminate
-- instead of overflowing the stack, and shared sub-tables stay shared in the
-- copy.
-- @param orig value to copy
-- @param copies optional memo mapping already-copied tables to their copies;
--   used by the recursion, callers normally omit it
-- @return the copy of `orig`
function deepcopy(orig, copies)
  if type(orig) ~= 'table' then
    return orig
  end
  copies = copies or {}
  local existing = copies[orig]
  if existing ~= nil then
    return existing
  end
  local copy = {}
  -- Register before recursing so self-references resolve to this copy.
  copies[orig] = copy
  for orig_key, orig_value in next, orig, nil do
    copy[deepcopy(orig_key, copies)] = deepcopy(orig_value, copies)
  end
  -- Attach the metatable only after filling, so the assignments above are raw
  -- and cannot trigger the metatable's __newindex.
  setmetatable(copy, deepcopy(getmetatable(orig), copies))
  return copy
end
--------------------------------------------------------------------------- | |
-- Extract the various mask regions based on color.

-- Target level of the R, G and B channel for every recognised region colour.
-- Level 0 means "below tolerance", 1 means "above 1 - tolerance", and any
-- other level means "within tolerance of that level".
local MASK_COLOR_LEVELS = {
  green     = {0, 1, 0},
  black     = {0, 0, 0},
  white     = {1, 1, 1},
  red       = {1, 0, 0},
  blue      = {0, 0, 1},
  yellow    = {1, 1, 0},
  grey      = {0.5, 0.5, 0.5},
  lightblue = {0, 1, 1},
  purple    = {1, 0, 1},
}

-- Byte mask (0/1) of the entries of one channel that are close to `level`.
local function channel_near_level(channel, level, tol)
  if level == 0 then
    return torch.lt(channel, tol)
  elseif level == 1 then
    return torch.gt(channel, 1 - tol)
  end
  return torch.cmul(torch.gt(channel, level - tol), torch.lt(channel, level + tol))
end

--- Build a binary mask of the pixels of `seg` that match a named colour.
-- @param seg colour-coded segmentation image indexed by channel, with seg[1],
--   seg[2], seg[3] treated as R, G, B. Presumably a 3xHxW tensor with values
--   in [0, 1], as returned by image.load(path, 3) — confirm at the caller.
-- @param color one of: green, black, white, red, blue, yellow, grey,
--   lightblue, purple
-- @param tolerance optional per-channel tolerance; defaults to 0.1
-- @return FloatTensor with 1 where every channel matches and 0 elsewhere
-- @raise error 'Mask color not recognized' for an unknown colour
function ExtractMask(seg, color, tolerance)
  local tol = tolerance or 0.1
  local levels = MASK_COLOR_LEVELS[color]
  if levels == nil then
    error('Mask color not recognized')
  end
  -- A pixel belongs to the region only if all three channels match, so the
  -- per-channel byte masks are multiplied together.
  local mask = channel_near_level(seg[1], levels[1], tol)
  mask:cmul(channel_near_level(seg[2], levels[2], tol))
  mask:cmul(channel_near_level(seg[3], levels[3], tol))
  return mask:float()
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment