@InnovArul
Created April 4, 2016 13:38
if opt.traintype == 'finetuning' and epoch == 1 then
    -- Fine-tuning setup: freeze every layer before the first nn.Linear
    -- (the convolutional feature extractor) and re-initialize the fully
    -- connected layers so that only they are trained.
    local dontFreeze = 0

    for index, node in ipairs(model.modules) do
        if torch.typename(node) == 'nn.Linear' then
            print(torch.typename(node) .. ' found')
            dontFreeze = 1           -- stop freezing from the first linear layer onwards
            lastLinearLayer = node
            node:reset()             -- re-initialize the linear layer's weights
        end

        -- Freeze the conv layers by overriding accGradParameters with a
        -- no-op, so no gradients are accumulated for their parameters.
        if dontFreeze == 0 then
            --node.updateGradInput = function(self, i, o) end -- would also skip gradInput
            node.accGradParameters = function(self, i, o) end
        end
    end

    -- Earlier, commented-out experiments instead assigned per-parameter
    -- hyperparameters: a 0.01 learning rate (finalLearningRates) for the
    -- conv parameters and a 5e-4 weight decay (finalWeightDecays) for the
    -- linear-layer parameters, halving the parameter count after the
    -- nn.CrossInputNeighborhood layer because the initial conv weights
    -- are shared.
end
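
-- For context: a minimal, self-contained sketch of the same freezing
-- technique on a hypothetical model (the nn.Sequential architecture and
-- layer sizes below are assumptions for illustration, not part of the
-- original gist). Overriding accGradParameters with a no-op stops gradient
-- accumulation in every layer before the first nn.Linear, while the reset
-- linear layer keeps learning.

require 'nn'

local demoModel = nn.Sequential()
demoModel:add(nn.SpatialConvolution(3, 16, 5, 5))  -- conv layer: to be frozen
demoModel:add(nn.ReLU())
demoModel:add(nn.Reshape(16 * 28 * 28))
demoModel:add(nn.Linear(16 * 28 * 28, 10))         -- classifier: reset and trainable

local reachedLinear = false
for _, node in ipairs(demoModel.modules) do
    if torch.typename(node) == 'nn.Linear' then
        reachedLinear = true                       -- stop freezing from here on
        node:reset()                               -- re-initialize the classifier
    end
    if not reachedLinear then
        node.accGradParameters = function(self, i, o) end  -- no-op: layer is frozen
    end
end

-- After one forward/backward pass, only the linear layer accumulates gradients.
local input = torch.randn(3, 32, 32)
local output = demoModel:forward(input)
demoModel:zeroGradParameters()
demoModel:backward(input, torch.randn(10))
print(demoModel.modules[1].gradWeight:norm(), demoModel.modules[4].gradWeight:norm())
-- the conv gradWeight norm stays 0 (frozen); the linear gradWeight norm is non-zero

-- Note: overriding accGradParameters leaves updateGradInput intact, so
-- backpropagation still flows through the frozen layers; the gist's
-- commented-out updateGradInput override would skip that work as well.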
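
-- The commented-out experiments summarized above point at an alternative to
-- hard freezing: per-parameter learning rates and weight decays. A hedged
-- sketch follows (the model, sizes, and data are made up for illustration;
-- finalLearningRates/finalWeightDecays echo the gist's variable names),
-- using optim.sgd's learningRates / weightDecays vector options.

require 'nn'
require 'optim'

local net = nn.Sequential()
net:add(nn.Linear(4, 8))
net:add(nn.Tanh())
net:add(nn.Linear(8, 2))
local criterion = nn.MSECriterion()

local parameters, gradParameters = net:getParameters()

-- one entry per scalar parameter; slices of these vectors could be set per
-- layer, as the gist's range assignments did
local finalLearningRates = torch.Tensor(parameters:size(1)):fill(0.01)
local finalWeightDecays  = torch.Tensor(parameters:size(1)):fill(5e-4)

local input, target = torch.randn(4), torch.randn(2)
local feval = function(x)
    gradParameters:zero()
    local output = net:forward(input)
    local loss = criterion:forward(output, target)
    net:backward(input, criterion:backward(output, target))
    return loss, gradParameters
end

optim.sgd(feval, parameters, {
    learningRate  = 1,                   -- global multiplier on the per-parameter rates
    learningRates = finalLearningRates,  -- individual learning rate per parameter
    weightDecays  = finalWeightDecays,   -- individual weight decay per parameter
})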