Last active
March 11, 2017 10:30
-
-
Save vadimkantorov/6f2e99c44cc7c73a5fbb5a47a6802864 to your computer and use it in GitHub Desktop.
Torch routines for saving and loading acyclic objects with HDF5
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- Usage:
-- hdf5_save('/path/to/my.h5', {a = torch.Tensor(1), b = {torch.CudaTensor(2)}, 'string', c = 5, d = true})
-- obj = hdf5_load('/path/to/my.h5')
-- obj = hdf5_load('/path/to/my.h5', {'a', 'b'})
-- tensor = hdf5_load('/path/to/my.h5', {'a'})  -- a single field is returned unwrapped
require 'hdf5' | |
--- Save an acyclic Lua value to an HDF5 file (the file is opened with mode 'w').
-- Tables are walked recursively; each leaf is written at the slash-joined path
-- of its keys (e.g. {a = {b = x}} is written at '/a/b'). Leaf encodings:
--   tensor  -> written as a CPU tensor (CUDA tensors are converted first)
--   number  -> 1-element DoubleTensor
--   string  -> CharTensor holding the string's bytes
--   boolean -> 1-element IntTensor holding 1 (true) or 0 (false)
-- A non-table top-level value is written at '/data'.
-- NOTE(review): keys containing '/' are not escaped and therefore become nested
-- paths; cycles are not detected, so a cyclic table recurses without bound.
-- @param path destination file path
-- @param obj  tensor, number, string, boolean, or a (nested) table of these
-- @raise error('unknown object type: ...') on any other leaf type; the file
--        handle is closed before the error is re-raised
function hdf5_save(path, obj)
  local function dfs(datapath, obj, h)
    if torch.isTensor(obj) then
      -- torch.CudaTensor holds floats; other CUDA types map by name
      -- (e.g. torch.CudaDoubleTensor -> torch.DoubleTensor). The parentheses
      -- keep only the first result of gsub (it also returns a match count).
      local cpu_tensor_type = obj:type() == 'torch.CudaTensor' and 'torch.FloatTensor' or (obj:type():gsub('Cuda', ''))
      -- Empty tensors are replaced by a fresh tensor resized to 1 and back to 0.
      -- NOTE(review): presumably works around the HDF5 writer rejecting tensors
      -- without allocated storage — confirm.
      local data = obj:numel() > 0 and obj:type(cpu_tensor_type) or torch.factory(cpu_tensor_type)():resize(1):resize(0)
      h:write(datapath, data)
    elseif type(obj) == 'number' then
      h:write(datapath, torch.DoubleTensor(1):fill(obj))
    elseif type(obj) == 'string' then
      h:write(datapath, torch.CharTensor(torch.CharStorage():string(obj)))
    elseif type(obj) == 'boolean' then
      -- Encode the boolean itself (previously referenced an out-of-scope `v`,
      -- which always evaluated to nil and stored 0).
      h:write(datapath, torch.IntTensor(1):fill(obj and 1 or 0))
    elseif type(obj) == 'table' then
      for k, v in pairs(obj) do
        dfs(datapath .. '/' .. k, v, h)
      end
    else
      error('unknown object type: ' .. type(obj) .. ', datapath: ' .. datapath)
    end
  end

  local h = hdf5.open(path, 'w')
  -- Walk under pcall so the handle is always closed; re-raise the original
  -- message unchanged (level 0 adds no extra position prefix).
  local ok, err = pcall(dfs, type(obj) == 'table' and '' or '/data', obj, h)
  h:close()
  if not ok then
    error(err, 0)
  end
end
--- Load values written by hdf5_save from an HDF5 file.
-- Leaf tensors are decoded back to Lua values when the encoding is recognized:
--   CharTensor / ByteTensor             -> string (the storage's bytes)
--   1-element DoubleTensor              -> number (via squeeze)
--   1-element IntTensor holding 0 or 1  -> boolean
--   anything else                       -> returned unchanged
-- Table keys that tonumber() accepts are converted to numbers, so array
-- indices stored as '1', '2', ... come back as numeric keys.
-- NOTE(review): the decoding is heuristic — e.g. any saved 1-element
-- DoubleTensor loads as a number and any ByteTensor loads as a string.
-- NOTE(review): assumes squeeze() on a 1-element tensor yields a Lua number.
-- Result unwrapping: if the decoded top-level table has exactly one entry, that
-- entry's value is returned instead of the table (so a value saved at '/data',
-- or a single requested field, comes back bare); an empty result returns nil.
-- @param path   HDF5 file to read
-- @param fields optional field name, or list of names, read from the file
--               root; fields that cannot be read are skipped silently
-- @return the decoded table, or the single decoded value (see above)
-- @raise re-raises any error from reading the whole file, after closing it
function hdf5_load(path, fields)
  local function dfs(obj)
    if torch.isTypeOf(obj, torch.CharTensor) or torch.isTypeOf(obj, torch.ByteTensor) then
      return obj:storage():string()
    elseif torch.isTypeOf(obj, torch.DoubleTensor) and obj:nElement() == 1 then
      return obj:squeeze()
    elseif torch.isTypeOf(obj, torch.IntTensor) and obj:nElement() == 1 and (obj:squeeze() == 0 or obj:squeeze() == 1) then
      return obj:squeeze() == 1
    elseif type(obj) == 'table' then
      local res = {}
      for k, v in pairs(obj) do
        res[tonumber(k) or k] = dfs(v)
      end
      return res
    end
    return obj
  end

  -- Read either the whole file or only the requested root-level fields.
  local function read_raw(h)
    if not fields then
      return h:all()
    end
    local out = {}
    for _, f in ipairs(type(fields) == 'table' and fields or {fields}) do
      -- Best effort: a field that cannot be read is simply left absent.
      pcall(function() out[f] = h:read('/' .. f):all() end)
    end
    return out
  end

  local h = hdf5.open(path, 'r')
  -- Read under pcall so the handle is closed even if reading fails.
  local ok, res = pcall(read_raw, h)
  h:close()
  if not ok then
    error(res, 0)
  end
  res = dfs(res)

  local first_key = next(res)
  if first_key == nil then
    return nil
  end
  if next(res, first_key) ~= nil then
    return res
  end
  return res[first_key]
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment