Skip to content

Instantly share code, notes, and snippets.

@toshi-k
Last active May 3, 2016 15:23
Show Gist options
  • Save toshi-k/07ac1993b6f8b7fd47e33a9fd25b4bff to your computer and use it in GitHub Desktop.
Save toshi-k/07ac1993b6f8b7fd47e33a9fd25b4bff to your computer and use it in GitHub Desktop.
Random Feature Extractor (Torch 7)
--- nn.RandomFeatureExtractor: a linear map (no bias) whose weight is a random
-- 0/1 mask of size outputSize x inputSize, built once in __init. Each output
-- unit therefore sums a random subset of between kmin and kmax input features.
-- `Parent` is the nn.Module class table; __init uses it to call nn.Module's
-- constructor.
local RandomFeatureExtractor, Parent = torch.class('nn.RandomFeatureExtractor', 'nn.Module')
--- Construct the module and draw its fixed random 0/1 mask.
-- For every output unit i, a count k is drawn uniformly from [kmin, kmax];
-- then k distinct input indices are taken from a random permutation of
-- 1..inputSize, and those entries of row i of the mask are set to 1.
-- NOTE: math.random is seeded by math.randomseed, while torch.randperm uses
-- torch's generator (torch.manualSeed); seed both for a reproducible mask.
-- @tparam number inputSize  number of input features (mask columns)
-- @tparam number outputSize number of output units (mask rows)
-- @tparam number kmin minimum inputs wired to each output unit (>= 0)
-- @tparam number kmax maximum inputs wired to each output unit (<= inputSize)
-- @raise error if kmin/kmax are not numbers or violate
--        0 <= kmin <= kmax <= inputSize
function RandomFeatureExtractor:__init(inputSize, outputSize, kmin, kmax)
   Parent.__init(self)

   if type(kmin) ~= 'number' or type(kmax) ~= 'number' then
      error('RandomFeatureExtractor: kmin and kmax must be numbers')
   end
   -- Without this check, kmax > inputSize indexes past the end of the
   -- permutation, and kmin > kmax yields a meaningless count under LuaJIT.
   if not (0 <= kmin and kmin <= kmax and kmax <= inputSize) then
      error(string.format(
         'RandomFeatureExtractor: need 0 <= kmin <= kmax <= inputSize '
            .. '(got kmin=%s, kmax=%s, inputSize=%s)',
         tostring(kmin), tostring(kmax), tostring(inputSize)))
   end

   self.mask = torch.Tensor(outputSize, inputSize):zero()
   for i = 1, outputSize do
      -- math.random accepts at most two bounds; a third argument is an
      -- error in PUC Lua ("wrong number of arguments").
      local num_samp = math.random(kmin, kmax)
      local index_samp = torch.randperm(inputSize)
      -- Row i as a view sharing the mask's storage, so writes land in the mask.
      local row = self.mask[i]
      for j = 1, num_samp do
         row[index_samp[j]] = 1
      end
   end

   self.inputSize = inputSize
   self.outputSize = outputSize
   self.kmin = kmin
   self.kmax = kmax
   self.output = torch.Tensor()
   self.gradInput = torch.Tensor()
end
--- Forward pass: apply the fixed mask as a linear map.
-- A 1D input (inputSize) gives output = mask * input (outputSize).
-- A 2D input is treated as a batch with one sample per row (batchSize x
-- inputSize), giving output = input * mask^T (batchSize x outputSize).
-- @param input 1D or 2D tensor
-- @return self.output, resized in place
-- @raise error for inputs that are neither a vector nor a matrix
function RandomFeatureExtractor:updateOutput(input)
   local nDim = input:dim()
   if nDim == 1 then
      self.output:resize(self.outputSize)
      self.output:mv(self.mask, input)
   elseif nDim == 2 then
      self.batchSize = input:size(1)
      self.output:resize(self.batchSize, self.outputSize)
      self.output:mm(input, self.mask:t())
   else
      -- Previously any non-1D input fell into the batch path and failed
      -- inside size()/mm() with an unhelpful message.
      error('input must be vector or matrix')
   end
   return self.output
end
--- Backward pass: propagate gradOutput through the fixed mask.
-- For a 1D input, gradInput = mask^T * gradOutput (same size as input).
-- For a 2D batch, gradInput = gradOutput * mask (batchSize x inputSize).
-- The mask is not modified here.
-- @param input      the tensor passed to updateOutput (1D or 2D)
-- @param gradOutput gradient w.r.t. this module's output
-- @return self.gradInput, resized in place
-- @raise error for inputs that are neither a vector nor a matrix
function RandomFeatureExtractor:updateGradInput(input, gradOutput)
   local nDim = input:dim()
   if nDim == 1 then
      self.gradInput:resizeAs(input)
      self.gradInput:mv(self.mask:t(), gradOutput)
   elseif nDim == 2 then
      self.batchSize = input:size(1)
      self.gradInput:resize(self.batchSize, self.inputSize)
      self.gradInput:mm(gradOutput, self.mask)
   else
      error('input must be vector or matrix')
   end
   return self.gradInput
end
--- Describe the module as its torch type name followed by its configuration,
-- e.g. "nn.RandomFeatureExtractor(10 -> 5, kmin: 1, kmax: 3)".
function RandomFeatureExtractor:__tostring__()
   local summary = ('(%d -> %d, kmin: %d, kmax: %d)'):format(
      self.inputSize, self.outputSize, self.kmin, self.kmax)
   return torch.type(self) .. summary
end
--[[
<<References>>
[1] 12th-place solution for the Otto Group Product Classification Challenge on Kaggle.
tks0123456789
https://github.com/tks0123456789/kaggle-Otto
--]]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment