@fmassa
Created November 1, 2016 16:25
Script to recover saved optnet models whose stored tensor pointers became invalid after optimization (e.g. once the model is saved and reloaded)
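require 'torch'
require 'nn'
local optnet = require 'optnet'

-- Load the optimized model (also require 'cunn' / 'cudnn' if it contains GPU modules).
local net = torch.load('celeba_24_G.t7')

-- net.__gradParamsInfo is keyed by the memory addresses (torch.pointer) of the
-- gradient tensors at the time the model was optimized. After reloading, the
-- tensors live at new addresses, so the saved keys are no longer valid.
-- To recover, assume that the gradient parameters, in the order returned by
-- net:parameters(), correspond to the saved entries sorted by ascending offSet.
local keys
do
  local t = {}
  for k, v in pairs(net.__gradParamsInfo) do
    table.insert(t, {k, v.offSet})
  end
  local tt = torch.LongTensor(t)
  -- sort by offset (column 2) and reorder the old pointers (column 1) to match
  local _, idx = tt:select(2, 2):sort(1)
  keys = tt:select(2, 1):index(1, idx)
end

-- Remap the old pointers to the new ones. Build a fresh table so that stale
-- keys are dropped and cannot collide with the new addresses.
local _, gp = net:parameters()
assert(keys:numel() <= #gp, 'more saved entries than gradient parameters')
local remapped = {}
for i = 1, keys:numel() do
  local ptr = torch.pointer(gp[i])
  remapped[ptr] = net.__gradParamsInfo[keys[i]]
end
net.__gradParamsInfo = remapped

-- Now the optimization can be removed.
optnet.removeOptimization(net)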
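To make the ascending-offset heuristic concrete, here is a minimal pure-Lua sketch that runs without torch or optnet. The helper `remapByOffset`, the toy pointer values and the offsets are all hypothetical; it assumes, as the script does, that the i-th smallest `offSet` belongs to the i-th gradient parameter returned by `net:parameters()`, and that each saved entry is a table with an `offSet` field.

```lua
-- Hypothetical helper (pure Lua, no torch): re-keys a table indexed by stale
-- pointers so that it is indexed by new pointers, ASSUMING the i-th smallest
-- offSet belongs to the i-th new pointer.
local function remapByOffset(info, newPtrs)
  -- collect the stale keys together with their offsets
  local entries = {}
  for oldPtr, v in pairs(info) do
    entries[#entries + 1] = {ptr = oldPtr, offSet = v.offSet}
  end
  -- sort by ascending offset, like tt:select(2, 2):sort(1) in the script
  table.sort(entries, function(a, b) return a.offSet < b.offSet end)
  assert(#entries == #newPtrs, 'number of entries and parameters differ')
  -- build a fresh table so stale keys cannot collide with new ones
  local remapped = {}
  for i, e in ipairs(entries) do
    remapped[newPtrs[i]] = info[e.ptr]
  end
  return remapped
end

-- Toy data: made-up pointer values and offsets, for illustration only.
local stale = {
  [1000] = {offSet = 11},
  [3000] = {offSet = 1},
  [2000] = {offSet = 6},
}
local newPtrs = {7, 8, 9}  -- stand-ins for torch.pointer(gp[i])
local fixed = remapByOffset(stale, newPtrs)
print(fixed[7].offSet, fixed[8].offSet, fixed[9].offSet)  -- expected: 1, 6, 11 (tab-separated)
```

In the real script, `newPtrs` would hold `torch.pointer(gp[i])` for each gradient parameter, and the result would be assigned back to `net.__gradParamsInfo` before calling `optnet.removeOptimization(net)`. Returning a fresh table rather than patching the old one in place means stale addresses are dropped and cannot clash with the new ones.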