Skip to content

Instantly share code, notes, and snippets.

@imenurok
Created August 24, 2017 07:52
Show Gist options
  • Save imenurok/a1274f156c8dc4f890e4099512047f2a to your computer and use it in GitHub Desktop.
Save imenurok/a1274f156c8dc4f890e4099512047f2a to your computer and use it in GitHub Desktop.
--- nn.NG: a learnable lower bound on the input.
--
-- Forward computes, elementwise,
--     output = max(input, weight)
-- With `nOutputPlane` given, there is one weight per channel, broadcast over
-- dimension 2 of a 4D input; TODO confirm the intended layout is
-- (batch, channel, height, width). Without it, one shared scalar weight is
-- applied to every element and the input may have any shape.
--
-- Backward routes each gradient element to exactly one branch of the max:
-- to the input where input >= weight, and to the weight where input < weight.
--
-- (This was adapted from nn.PReLU, but the operation is a learnable
-- threshold, not a parametric ReLU.)
local NG, parent = torch.class('nn.NG', 'nn.Module')

--- Constructor.
-- @tparam[opt] number nOutputPlane number of channels (size of input dim 2).
--   When omitted (or 0) a single scalar weight is shared by all elements.
function NG:__init(nOutputPlane)
   parent.__init(self)
   -- 0 marks shared-weight mode; the weight tensor then holds one scalar.
   self.nOutputPlane = nOutputPlane or 0
   local nWeights = math.max(self.nOutputPlane, 1)
   self.weight = torch.Tensor(nWeights):zero()
   -- torch.Tensor(n) is uninitialised; zero it so accumulation is well
   -- defined even before zeroGradParameters() is first called.
   self.gradWeight = torch.Tensor(nWeights):zero()
   self.buf = torch.Tensor()   -- weight broadcast to input size (reused by backward)
   self.buf2 = torch.Tensor()  -- masked gradOutput used to accumulate gradWeight
end

-- Fills layer.buf with the weight broadcast to the size of `input`.
local function fillExpandedWeight(layer, input)
   layer.buf:resizeAs(input)
   if layer.nOutputPlane == 0 then
      -- shared mode: a single scalar, valid for inputs of any shape
      layer.buf:fill(layer.weight[1])
   else
      layer.buf:copy(layer.weight:view(1, layer.nOutputPlane, 1, 1):expandAs(input))
   end
end

--- Forward pass: output = max(input, weight), elementwise.
-- Also caches the broadcast weight in self.buf for the backward pass.
function NG:updateOutput(input)
   fillExpandedWeight(self, input)
   self.output:resizeAs(input):copy(input)
   self.output:cmax(self.buf)
   return self.output
end

--- Gradient w.r.t. the input.
-- Relies on self.buf from the preceding updateOutput call.
function NG:updateGradInput(input, gradOutput)
   if self.gradInput then
      self.gradInput:resizeAs(input):copy(gradOutput)
      -- where the weight strictly won the max, output does not depend on input
      self.gradInput[torch.lt(input, self.buf)] = 0
      return self.gradInput
   end
end

--- Accumulates scale * dLoss/dWeight into self.gradWeight.
-- Relies on self.buf from the preceding updateOutput call.
-- @param scale optional multiplier, defaults to 1
-- @return self.gradWeight
function NG:accGradParameters(input, gradOutput, scale)
   scale = scale or 1
   self.buf2:resizeAs(input):copy(gradOutput)
   -- Complement of updateGradInput's mask: ties (input == weight) already
   -- send their gradient to the input, so only strict input < weight counts
   -- here; otherwise the same gradient would be credited to both branches.
   self.buf2[torch.ge(input, self.buf)] = 0
   if self.nOutputPlane == 0 then
      -- shared scalar weight receives the sum over every element
      self.gradWeight[1] = self.gradWeight[1] + scale * self.buf2:sum()
   else
      -- reduce over all dimensions except the channel dimension (2)
      local perPlane = self.buf2:sum(1):sum(3):sum(4):view(self.nOutputPlane)
      self.gradWeight:add(scale, perPlane)
   end
   return self.gradWeight
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment