Skip to content

Instantly share code, notes, and snippets.

@htoyryla
Last active February 13, 2018 10:59
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save htoyryla/49cb3ab0864d2a12f558631c7b3d87a3 to your computer and use it in GitHub Desktop.
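-- This program runs an image through a Caffe model (loaded with loadcaffe)
-- and lists the nc channels of a given layer that have the strongest
-- activations, measured as the L2 norm of each channel's activation map.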
require 'torch'
require 'nn'
require 'image'
require 'loadcaffe'
--- Convert an RGB image tensor into the input format of the Caffe VGG model.
-- The channels are reordered RGB -> BGR, the values are multiplied by 256 and
-- the per-channel mean pixel (given in BGR order) is subtracted.
-- @param img 3-channel image tensor, channels along dimension 1
--   (presumably values in [0, 1] as returned by image.load -- TODO confirm)
-- @return a new tensor; index() copies, so the caller's img is left untouched
function preprocess(img)
  local bgr_order = torch.LongTensor{3, 2, 1}
  local caffe_img = img:index(1, bgr_order)
  caffe_img:mul(256.0)

  -- mean pixel per BGR channel, broadcast over height and width
  local mean = torch.DoubleTensor({103.939, 116.779, 123.68}):view(3, 1, 1)
  caffe_img:add(-1, mean:expandAs(caffe_img))
  return caffe_img
end
-- Command-line options.
local cmd = torch.CmdLine()
cmd:option('-image', 'examples/inputs/tubingen.jpg', 'input image')
cmd:option('-image_size', 800, 'maximum side length of the scaled input image')
cmd:option('-proto', 'models/VGG_ILSVRC_19_layers_deploy.prototxt', 'Caffe model definition (prototxt)')
cmd:option('-model', 'models/VGG_ILSVRC_19_layers.caffemodel', 'Caffe model weights (caffemodel)')
cmd:option('-layer', 'relu4_2', 'layer for examine')
-- Numeric default so torch.CmdLine parses -nc as a number, not a string.
cmd:option('-nc', 10, 'number of best channels to be shown')
local params = cmd:parse(arg)

-- Load the image, scale it and convert it to the model's BGR,
-- mean-subtracted input format (preprocess already returns a fresh tensor,
-- so no extra clone is needed).
local content_image = image.load(params.image, 3)
content_image = image.scale(content_image, params.image_size, 'bilinear')
local img = preprocess(content_image):float()

-- Build a network containing the model's layers up to and including the
-- requested one.
local cnn = loadcaffe.load(params.proto, params.model, "nn"):float()
local net = nn.Sequential()
local found = false
for i = 1, #cnn do
  local layer = cnn:get(i)
  print(layer.name, torch.type(layer))
  net:add(layer)
  if layer.name == params.layer then
    found = true
    break
  end
end
if not found then
  print("No such layer: "..params.layer)
  return
end

-- Forward pass. y holds the requested layer's output, with channels along
-- dimension 1.
local y = net:forward(img)
local n = y:size(1)

-- L2 norm of each channel. narrow() returns a view, so the activation tensor
-- is not copied once per channel. This also works for non-spatial (1-D)
-- layer outputs.
local activ = torch.Tensor(n)
for i = 1, n do
  activ[i] = y:narrow(1, i, 1):norm()
end

-- Sort channels by activation strength, strongest first, and print the top
-- nc. The count is clamped so it never indexes past the number of channels.
local sorted_norms, order = torch.sort(activ, 1, true)
for i = 1, math.min(params.nc, n) do
  print(order[i], sorted_norms[i])
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment