Skip to content

Instantly share code, notes, and snippets.

@JoostvDoorn
Created November 21, 2016 19:49
Show Gist options
  • Save JoostvDoorn/d5e2787a0a307fcc126acf41c9f749bf to your computer and use it in GitHub Desktop.
Save JoostvDoorn/d5e2787a0a307fcc126acf41c9f749bf to your computer and use it in GitHub Desktop.
------------------------------------------------------------------------
--[[ CriterionMaskZero ]]--
-- Decorator that drops every sample whose target is zero before calling
-- the encapsulated criterion, so masked samples contribute nothing to the
-- loss and receive a zero gradInput
-- note: Use at your own risk
-- author: JoostvDoorn
------------------------------------------------------------------------
-- Register nn.CriterionMaskZero as a subclass of nn.Criterion. `parent` is
-- the nn.Criterion class table, used by the methods below for super-calls
-- (parent.__init, parent.type).
local CriterionMaskZero, parent = torch.class("nn.CriterionMaskZero", "nn.Criterion")
--- Wrap `criterion` so that samples with a zero target are masked out.
-- @param criterion an nn.Criterion instance; it is only ever evaluated on
--   the samples whose target is non-zero
function CriterionMaskZero:__init(criterion)
   -- validate before touching any state, and say what was actually passed
   assert(torch.isTypeOf(criterion, 'nn.Criterion'),
      'CriterionMaskZero expects an nn.Criterion, got ' .. torch.type(criterion))
   parent.__init(self)
   self.criterion = criterion
   self.gradInput = torch.Tensor()
end
--- Evaluate the wrapped criterion on the non-zero-target samples only.
-- Caches the per-sample mask (`idx`), its row-broadcast version (`idx2`)
-- and the selected rows (`input`, `target`) on self for updateGradInput.
-- NOTE(review): the expandAs below requires a 2-D input whose first
-- dimension equals target's element count; presumably (batch x classes)
-- with class-index targets -- confirm against callers.
-- @param input 2-D tensor of predictions
-- @param target tensor of targets; entries equal to 0 are masked out
-- @return the wrapped criterion's loss, or 0 when every sample is masked
function CriterionMaskZero:updateOutput(input, target)
   local rowMask = target:ne(0)
   self.idx = rowMask
   -- broadcast the per-sample mask across every column of its input row
   self.idx2 = rowMask:view(-1, 1):expandAs(input)

   if rowMask:sum() <= 0 then
      -- when all samples are masked, then loss is zero (issue 128)
      self.output = 0
      return self.output
   end

   local keptTarget = target[rowMask]
   self.target = keptTarget
   -- mask selection flattens, so restore one row per kept sample
   self.input = input[self.idx2]:view(keptTarget:size(1), -1)
   self.output = self.criterion:updateOutput(self.input, self.target)
   return self.output
end
--- Gradient of the masked loss with respect to `input`.
-- Rows whose target is zero receive a zero gradient; all other rows
-- receive the wrapped criterion's gradient at their original position.
-- Must be called after updateOutput on the same input/target, because it
-- reuses the masks (`idx`, `idx2`) and selections (`input`, `target`)
-- cached there.
-- @param input the input that was passed to updateOutput
-- @param target the target that was passed to updateOutput
-- @return self.gradInput, resized to the size of input
function CriterionMaskZero:updateGradInput(input, target)
   self.gradInput:resizeAs(input):zero()
   local nKept = self.idx:sum()
   if nKept == self.idx:nElement() then
      -- nothing is masked: the wrapped criterion sees the full batch
      self.gradInput:copy(self.criterion:updateGradInput(input, target))
   elseif nKept > 0 then
      -- Scatter the gradient of the selected rows back to their original
      -- positions. Indexing with a mask (`self.gradInput[self.idx2]`)
      -- returns a copy, so copying into it would be lost; maskedCopy fills
      -- the masked positions in the same order maskedSelect read them.
      local keptGrad = self.criterion:updateGradInput(self.input, self.target)
      self.gradInput:maskedCopy(self.idx2, keptGrad)
   end
   -- nKept == 0: every sample is masked and the gradient stays all zeros
   return self.gradInput
end
--- Convert this criterion (and the wrapped one) to tensor type `type`.
-- Per-batch caches are dropped instead of converted: updateOutput rebuilds
-- the masks and selections on its next call.
-- @param type target tensor type name, e.g. 'torch.FloatTensor'
-- @return self (from parent.type)
function CriterionMaskZero:type(type, ...)
   -- not referenced elsewhere in this class; still cleared as before
   self.zeroMask = nil
   -- the masks this class actually caches in updateOutput
   self.idx = nil
   self.idx2 = nil
   self.input = nil
   self.target = nil
   self.gradInput = self.gradInput:type(type)
   return parent.type(self, type, ...)
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment