Created June 21, 2016 16:25
-
-
Save szagoruyko/40947cd9a554e906f75fe1f7cef66f66 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- Dependencies: Penlight table helpers plus the Torch7 CUDA (cunn) and
-- cuDNN bindings.
local tablex = require('pl.tablex')
require('cunn')
require('cudnn')

-- Project helpers loaded by absolute path; `utils` is used later in this
-- script (incorporateBNtoConvAndLinear).
local utils = dofile('/opt/projects/coco/fastrcnn/models/model_utils.lua')
-- Executed only for its side effects (the return value is discarded).
-- NOTE(review): presumably this registers custom nn modules that
-- model_35.t7 needs in order to deserialize, so it must stay before
-- torch.load below — confirm.
dofile('/home/zagoruys/projects/cifar2.torch/augmentation.lua')

-- Load the trained checkpoint and swap cudnn.* modules for nn.* equivalents.
local net = torch.load('./model_35.t7')
cudnn.convert(net, nn)
-- Remove the 43rd module of the top-level container.
-- NOTE(review): presumably the output activation that followed the
-- parametrized classifier at index 42 — confirm.
net:remove(43)
print(net)

-- Class-name lists: the ordering the checkpoint was trained with, and the
-- ordering the exported model should use.
-- NOTE(review): assumed to be Lua sequences of class names — confirm.
local orig_classes = torch.load('./classes.t7')
local target_classes = torch.load('./target_classes.t7')
--- Permute the output rows of a parametrized layer into a new class order.
-- After the call, row i of `layer.weight` (and of `layer.bias`, when present)
-- holds the parameters previously stored for class `new_order[i]`, so the
-- layer is ordered like `new_order`.
-- Argument order: the desired ordering comes first, the current one second.
-- @param layer module whose `weight` (and optional `bias`) rows are indexed along dim 1
-- @tparam table new_order sequence of class names in the desired row order
-- @tparam table current_order sequence of class names in the layer's current row order
-- @raise if the class count differs from the layer's row count, or if a class
--   in `new_order` does not appear in `current_order`
local function fixClassOrdering(layer, new_order, current_order)
  -- Reverse lookup: class name -> current row. First occurrence wins, which
  -- matches tablex.find semantics, but costs O(n) instead of O(n^2).
  local position = {}
  for i, name in ipairs(current_order) do
    if position[name] == nil then
      position[name] = i
    end
  end

  local n = #new_order
  local rows = layer.weight:size(1)
  assert(rows == n, string.format(
    'fixClassOrdering: layer has %d rows but %d classes were given', rows, n))

  local mapping = torch.LongTensor(n)
  for i = 1, n do
    local name = new_order[i]
    local src = position[name]
    if src == nil then
      error(string.format(
        'fixClassOrdering: class %s not found in current ordering',
        tostring(name)), 2)
    end
    mapping[i] = src
  end

  -- index() returns a fresh tensor (not a view), so copying it back into the
  -- same storage is safe.
  layer.weight:copy(layer.weight:index(1, mapping))
  -- Some parametrized modules are built without a bias.
  if layer.bias then
    layer.bias:copy(layer.bias:index(1, mapping))
  end
end
-- Re-order the parametrized layer at index 42 from the checkpoint's class
-- order (orig_classes) to the export order (target_classes). fixClassOrdering
-- takes the desired ordering as its second argument and the layer's current
-- ordering as its third.
do
  local classifier = net:get(42)
  fixClassOrdering(classifier, target_classes, orig_classes)
end
net.classes = target_classes

-- Per-channel input normalization statistics stored on the model.
-- NOTE(review): three entries each, presumably RGB channels of images scaled
-- to [0, 1] — confirm against the training pipeline.
do
  local mean = { 0.48462227599918, 0.45624044862054, 0.40588363755159 }
  local std = { 0.22889466674951, 0.22446679341259, 0.22495548344775 }
  net.transform = { mean = mean, std = std }
end
--- Serialize `net` to `name` with its gradient buffers stripped.
-- Side effects on the caller's `net`: every module's gradWeight/gradBias is
-- set to nil, and the network is converted to float in place (net:float()).
-- The saved object is a table `{ model = net, unpack = fn }`. After loading,
-- call `obj:unpack()` to re-create the gradient buffers; it returns the model.
-- @tparam string name output path passed to torch.save
-- @param net nn module/container to save
local function saveNoGrad(name, net)
  for _, module in ipairs(net:listModules()) do
    module.gradWeight = nil
    module.gradBias = nil
  end

  local ts = {
    model = net:float(),
    -- Re-allocates zero-filled gradient buffers shaped like the parameters.
    -- Keep this closure free of references to enclosing locals: torch.save
    -- serializes the function together with any upvalues it captures.
    unpack = function(self)
      for _, module in ipairs(self.model:listModules()) do
        if module.weight and not module.gradWeight then
          -- Gradients start at zero, not as a copy of the weights, so a
          -- backward pass that accumulates before zeroGradParameters() is
          -- not polluted by parameter values.
          module.gradWeight = module.weight:clone():zero()
          -- Some parametrized modules carry no bias; don't crash on them.
          if module.bias then
            module.gradBias = module.bias:clone():zero()
          end
        end
      end
      return self.model
    end,
  }
  torch.save(name, ts)
end
-- Export two variants of the same network: one with its BatchNorm layers
-- intact, and one after utils.incorporateBNtoConvAndLinear (by its name,
-- folding BN into neighbouring conv/linear layers — confirm in
-- model_utils.lua).
do
  local with_bn_path = 'nin_bn_final.t7'
  local without_bn_path = 'nin_nobn_final.t7'

  saveNoGrad(with_bn_path, net)

  -- saveNoGrad converted `net` to float; move it to the GPU before the BN pass.
  -- NOTE(review): presumably incorporateBNtoConvAndLinear expects CUDA modules
  -- and edits `net` in place (its return value is ignored) — confirm.
  net:cuda()
  utils.incorporateBNtoConvAndLinear(net)

  local float_net = net:float()
  saveNoGrad(without_bn_path, float_net)
end
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.