-- nn.MaskedLogSoftMax: a masked log-softmax module for Torch7 nn.
-- Source: GitHub Gist minhlab/ff4d970be146eda229fb
-- (by @minhlab, created December 11, 2015 21:43).
local MaskedLogSoftMax, Parent = torch.class('nn.MaskedLogSoftMax', 'nn.Module')

--- Construct a masked log-softmax module.
-- @param masks tensor of masks; updateOutput selects rows of it (index along
--   dim 1) using the indices passed as input[2].
-- @param filler accepted for interface compatibility but not used anywhere.
function MaskedLogSoftMax:__init(masks, filler)
   Parent.__init(self)
   self.masks = masks
   -- Empty scratch tensors; their contents are produced by updateOutput.
   self.mininds = torch.LongTensor()
   local scratch = {'minvals', 'temp1', 'temp2', 'temp3', 'temp4', 'batch_masks'}
   for _, field in ipairs(scratch) do
      self[field] = torch.Tensor()
   end
end
--- Forward pass: masked log-softmax along dimension 2.
-- input[1] (data) holds the scores. input[2] (states) selects one mask row per
-- data row via index along dim 1 of self.masks. For every row i:
--   L[i]         = log(sum_k mask[i][k] * exp(data[i][k]))
--   output[i][j] = mask[i][j] * (data[i][j] - L[i])
-- Entries whose mask is 0 get output 0. A row whose mask is all zeros
-- yields an all-zero output row instead of nan.
-- NOTE(review): assumes data is 2-D (rows x classes) and states is a
-- LongTensor of row indices into self.masks — confirm against callers.
-- @param input table {data, states}
-- @return self.output
function MaskedLogSoftMax:updateOutput(input)
   local data, states = input[1], input[2]
   local mask = self.batch_masks
   mask:index(self.masks, 1, states)
   local excluded = mask:eq(0)

   -- Copy of the scores with masked-out entries forced to -inf. They then
   -- contribute exp(-inf) = 0 instead of possibly overflowing to inf (which
   -- would turn into nan once multiplied by a zero mask).
   self.maskedData = self.maskedData or data.new()
   self.maskedData:resizeAs(data):copy(data)
   self.maskedData:maskedFill(excluded, -math.huge)

   -- Log-sum-exp trick: shift by the row-wise max of the *allowed* entries so
   -- every exponent is <= 0 (no overflow) and the largest allowed term is 1.
   self.maxvals = self.maxvals or data.new()
   self.maxinds = self.maxinds or torch.LongTensor()
   torch.max(self.maxvals, self.maxinds, self.maskedData, 2)
   -- A fully masked row has max -inf; shift by 0 instead to avoid -inf - -inf.
   self.maxvals:maskedFill(self.maxvals:eq(-math.huge), 0)

   self.temp1:add(self.maskedData, -1, self.maxvals:expandAs(data)):exp()
   self.temp2:resizeAs(data):cmul(self.temp1, mask)
   self.temp3:sum(self.temp2, 2)
   self.temp3:log()
   self.temp3:add(self.maxvals)
   -- Fully masked row: log(0) = -inf; use 0 so its output row is 0, not nan.
   self.temp3:maskedFill(self.temp3:eq(-math.huge), 0)

   self.temp4:add(data, -1, self.temp3:expandAs(data))
   self.output:resizeAs(data):cmul(self.temp4, mask)
   return self.output
end
--- Backward pass of the masked log-softmax.
-- With p = exp(output) .* mask (the masked softmax probabilities, 0 where
-- masked out) and g = gradOutput .* mask:
--   gradData[i][j] = g[i][j] - p[i][j] * sum_k g[i][k]
-- states is an index input, so it receives an all-zero gradient.
-- Uses self.output and self.batch_masks written by updateOutput; call it first
-- with the same input.
-- NOTE(review): p = exp(output) .* mask is exact only for 0/1 masks — confirm
-- masks are binary.
-- @param input table {data, states}
-- @param gradOutput gradient w.r.t. self.output
-- @return self.gradInput = {gradData, gradStates}
function MaskedLogSoftMax:updateGradInput(input, gradOutput)
   local data, states = input[1], input[2]
   local mask = self.batch_masks

   self.gradData = self.gradData or data.new()
   self.gradStates = self.gradStates or states.new()
   self.prob = self.prob or data.new()
   self.gradSum = self.gradSum or data.new()

   -- p = exp(output) .* mask
   self.prob:resizeAs(self.output):copy(self.output):exp()
   self.prob:cmul(mask)

   -- g = gradOutput .* mask, then the row sums of g
   self.gradData:resizeAs(gradOutput):cmul(gradOutput, mask)
   self.gradSum:sum(self.gradData, 2)

   -- gradData = g - p .* rowsum(g)
   self.prob:cmul(self.gradSum:expandAs(self.prob))
   self.gradData:add(-1, self.prob)

   self.gradStates:resize(states:size()):zero()
   self.gradInput = {self.gradData, self.gradStates}
   return self.gradInput
end
-- End of gist minhlab/ff4d970be146eda229fb.