Skip to content

Instantly share code, notes, and snippets.

@vadimkantorov
Last active March 11, 2017 10:30
Show Gist options
  • Save vadimkantorov/6f2e99c44cc7c73a5fbb5a47a6802864 to your computer and use it in GitHub Desktop.
Save vadimkantorov/6f2e99c44cc7c73a5fbb5a47a6802864 to your computer and use it in GitHub Desktop.
Torch routines for saving and loading acyclic objects with HDF5
-- Usage:
-- hdf5_save('/path/to/my.h5', {a = torch.Tensor(1), b = {torch.CudaTensor(2)}, 'string', c = 5, d = true})
-- obj = hdf5_load('/path/to/my.h5')
-- obj = hdf5_load('/path/to/my.h5', {'a', 'b'})
-- tensor = hdf5_load('/path/to/my.h5', {'a'})
require 'hdf5'
-- Save an acyclic Lua object to the HDF5 file at `path` (opened with mode 'w').
-- Tables become HDF5 groups: each key becomes a path component ('/a/b/1', ...)
-- and each non-table leaf becomes one dataset, encoded as:
--   * tensor  -> the tensor itself; CUDA tensors are first converted to the
--                matching CPU type (torch.CudaTensor -> torch.FloatTensor,
--                torch.CudaXTensor -> torch.XTensor)
--   * number  -> 1-element DoubleTensor
--   * string  -> CharTensor holding the string's bytes
--   * boolean -> 1-element IntTensor holding 1 (true) or 0 (false)
-- A non-table `obj` is stored as the single dataset '/data'.
-- @param path  destination file path
-- @param obj   tensor, number, string, boolean, or a table nesting these
-- Raises an error on any other value type (e.g. functions); the file handle is
-- closed before the error is re-raised.
function hdf5_save(path, obj)
  -- Recursively write `obj` under the HDF5 path `datapath` of the open file `h`.
  local function dfs(datapath, obj, h)
    if torch.isTensor(obj) then
      local tensor_type = obj:type()
      local cpu_tensor_type
      if tensor_type == 'torch.CudaTensor' then
        cpu_tensor_type = 'torch.FloatTensor'
      else
        -- parentheses keep only the first result of gsub (the new string)
        cpu_tensor_type = (tensor_type:gsub('Cuda', ''))
      end
      if obj:numel() > 0 then
        h:write(datapath, obj:type(cpu_tensor_type))
      else
        -- Empty tensors are replaced by a fresh 0-element CPU tensor; presumably
        -- resize(1):resize(0) gives it an allocated storage that the HDF5
        -- writer needs (NOTE(review): confirm against torch-hdf5).
        h:write(datapath, torch.factory(cpu_tensor_type)():resize(1):resize(0))
      end
    elseif type(obj) == 'number' then
      h:write(datapath, torch.DoubleTensor(1):fill(obj))
    elseif type(obj) == 'string' then
      h:write(datapath, torch.CharTensor(torch.CharStorage():string(obj)))
    elseif type(obj) == 'boolean' then
      -- Previously `v and 1 or 0`: `v` was an undefined global (always nil),
      -- so every boolean was saved as 0.
      h:write(datapath, torch.IntTensor(1):fill(obj and 1 or 0))
    elseif type(obj) == 'table' then
      for k, v in pairs(obj) do
        dfs(datapath .. '/' .. k, v, h)
      end
    else
      error('unknown object type: ' .. type(obj) .. ', datapath: ' .. datapath)
    end
  end
  local h = hdf5.open(path, 'w')
  local ok, err = pcall(dfs, type(obj) == 'table' and '' or '/data', obj, h)
  -- Always release the file handle, even when encoding fails part-way.
  h:close()
  if not ok then
    -- level 0: err already carries the position of the original error()
    error(err, 0)
  end
end
-- Load an object written by hdf5_save from the HDF5 file at `path`.
-- Leaves are decoded by tensor type, mirroring hdf5_save's encoding:
--   * CharTensor / ByteTensor        -> string (the tensor's storage bytes)
--   * 1-element DoubleTensor         -> number
--   * 1-element IntTensor of 0 or 1  -> boolean
--   * anything else                  -> returned unchanged
-- Decoding looks only at the type, so a tensor the caller saved that matches
-- one of these patterns is decoded too (e.g. a 1-element DoubleTensor comes
-- back as a number). Keys that tonumber() can parse become numeric keys.
-- @param path    file to read
-- @param fields  optional: a top-level field name, or a list of them. Fields
--                that cannot be read (e.g. missing) are silently left nil.
-- @return  with a single field (a name or a one-element list): that field's
--          value; with several fields: a table keyed by field name; without
--          fields: the whole object, where an object saved as a non-table
--          (dataset '/data') is returned unwrapped. Note that a table saved
--          with 'data' as its only key is indistinguishable from that case.
-- Errors raised while reading the whole file are re-raised after the file
-- handle has been closed.
function hdf5_load(path, fields)
  -- Recursively convert the raw datasets back into Lua values.
  local function dfs(obj)
    if torch.isTypeOf(obj, torch.CharTensor) or torch.isTypeOf(obj, torch.ByteTensor) then
      return obj:storage():string()
    elseif torch.isTypeOf(obj, torch.DoubleTensor) and obj:nElement() == 1 then
      return obj:squeeze()
    elseif torch.isTypeOf(obj, torch.IntTensor) and obj:nElement() == 1 then
      local value = obj:squeeze()
      if value == 0 or value == 1 then
        return value == 1
      end
      return obj
    elseif type(obj) == 'table' then
      local res = {}
      for k, v in pairs(obj) do
        res[tonumber(k) or k] = dfs(v)
      end
      return res
    else
      return obj
    end
  end

  local h = hdf5.open(path, 'r')
  local ok, raw = pcall(function()
    if not fields then
      return h:all()
    end
    local res = {}
    for _, f in ipairs(type(fields) == 'table' and fields or {fields}) do
      -- Best-effort: an unreadable or missing field is simply left nil.
      pcall(function() res[f] = h:read('/' .. f):all() end)
    end
    return res
  end)
  -- Always release the file handle, even when reading fails.
  h:close()
  if not ok then
    error(raw, 0)
  end
  local res = dfs(raw)

  -- Unwrap only when the caller's request implies a single value. The previous
  -- rule ("exactly one top-level key") also unwrapped saved single-key tables
  -- and multi-field requests where only one field happened to exist.
  if fields then
    if type(fields) ~= 'table' or #fields == 1 then
      -- res holds at most one entry; next() avoids re-deriving its
      -- (possibly tonumber-converted) key.
      local _, value = next(res)
      return value
    end
    return res
  end
  local key, value = next(res)
  if key == 'data' and next(res, key) == nil then
    return value
  end
  return res
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment