@lebedov
Created September 16, 2016 13:37
Class for passing data to Torch's nn.StochasticGradient
-- Use the class rock to create a Dataset class
-- that can be used by nn.StochasticGradient in Torch
local class = require 'class'
local Dataset = class('Dataset')
function Dataset:__init(inputs, labels)
   self.inputs = inputs
   self.labels = labels
end

function Dataset:size()
   return self.inputs:size()[1]
end
-- Try looking keys up in the original __index before
-- treating the key as a sample index into self.inputs and self.labels
local t = Dataset.__index
Dataset.__index = function (self, k)
   if t[k] == nil then
      -- Not a class member: return the {input, label} pair for sample k
      return {self.inputs[k], self.labels[k]}
   else
      return t[k]
   end
end
return {Dataset = Dataset}
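
--[[
Usage sketch (not part of the original gist). It assumes this file is saved
as dataset.lua (a hypothetical name) and that the torch, nn and class rocks
are installed. The model, criterion and tensor sizes are illustrative only.

local torch = require 'torch'
local nn = require 'nn'
local Dataset = require('dataset').Dataset

-- 100 samples with 10 features each, and one regression target per sample
local inputs = torch.rand(100, 10)
local labels = torch.rand(100, 1)
local dataset = Dataset(inputs, labels)

-- nn.StochasticGradient calls dataset:size() and reads dataset[i][1]
-- (the input) and dataset[i][2] (the target) for each sample
local model = nn.Linear(10, 1)
local trainer = nn.StochasticGradient(model, nn.MSECriterion())
trainer.maxIteration = 5
trainer:train(dataset)
]]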