Skip to content

Instantly share code, notes, and snippets.

@soumith
Created July 7, 2015 14:44
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save soumith/fb1968f617808134b483 to your computer and use it in GitHub Desktop.
Save soumith/fb1968f617808134b483 to your computer and use it in GitHub Desktop.
require 'cudnn'
--- cudnn convolution followed by sub-pixel ("pixel shuffle") upsampling.
-- The parent cudnn.SpatialConvolution is built with nOutputPlane*factor^2
-- output maps, stride 1 and (k-1)/2 padding, so it keeps the input's spatial
-- size. Those maps are then rearranged into nOutputPlane maps that are
-- `factor` times larger in each spatial dimension:
--   out[c][y*f + i][x*f + j] = conv[c*f*f + i*f + j][y][x]
local SpatialConvolutionUpsample, parent = torch.class('cudnn.SpatialConvolutionUpsample','cudnn.SpatialConvolution')

-- Returns: whether `input` is batched (4D), the batch size (1 for 3D input),
-- and the spatial height and width (the last two dimensions).
local function spatialShape(input)
   local nDim = input:dim()
   local batched = nDim == 4
   return batched, batched and input:size(1) or 1, input:size(nDim - 1), input:size(nDim)
end

--- Constructor.
-- @param nInputPlane  number of input feature maps
-- @param nOutputPlane number of feature maps after upsampling
-- @param kW           kernel width; must be odd so the (kW-1)/2 padding keeps the width
-- @param kH           kernel height; must be odd so the (kH-1)/2 padding keeps the height
-- @param factor       integer upsampling factor (default 2)
-- @param groups       forwarded unchanged to cudnn.SpatialConvolution
function SpatialConvolutionUpsample:__init(nInputPlane, nOutputPlane, kW, kH, factor, groups)
   factor = factor or 2
   assert(kW and kH and nInputPlane and nOutputPlane)
   assert(kW % 2 == 1, 'kW has to be odd')
   assert(kH % 2 == 1, 'kH has to be odd')
   assert(factor >= 1 and factor % 1 == 0, 'factor has to be a positive integer')
   self.factor = factor
   self.kW = kW
   self.kH = kH
   self.nInputPlaneU = nInputPlane
   self.nOutputPlaneU = nOutputPlane
   parent.__init(self, nInputPlane, nOutputPlane * factor * factor,
                 kW, kH, 1, 1, (kW - 1) / 2, (kH - 1) / 2, groups)
end

--- Forward pass: convolve, then pixel-shuffle into the upsampled layout.
-- @param input 3D (C,h,w) or 4D (N,C,h,w) tensor
-- @return (nOutputPlane, h*f, w*f) or (N, nOutputPlane, h*f, w*f)
function SpatialConvolutionUpsample:updateOutput(input)
   -- Give the parent back its own buffer. Otherwise it would write into the
   -- upsampled tensor returned by the previous call.
   if self._convOutput then
      self.output = self._convOutput
   end
   local conv = parent.updateOutput(self, input)
   self._convOutput = conv

   local batched, n, h, w = spatialShape(input)
   local f, C = self.factor, self.nOutputPlaneU
   self.h, self.w = h, w -- kept for callers that read these fields

   -- conv channel index is c*f*f + i*f + j: move the (i, j) sub-pixel axes
   -- next to the (y, x) spatial axes -> (n, C, h, f, w, f).
   local shuffled = conv:view(n, C, f, f, h, w):permute(1, 2, 5, 3, 6, 4)
   self._upOutput = self._upOutput or conv.new()
   self._upOutput:resize(n, C, h, f, w, f):copy(shuffled)

   if batched then
      self.output = self._upOutput:view(n, C, h * f, w * f)
   else
      self.output = self._upOutput:view(C, h * f, w * f)
   end
   return self.output
end

-- Inverse of the forward shuffle: folds an upsampled-layout gradient back into
-- the parent's (C*f*f, h, w) / (N, C*f*f, h, w) contiguous layout.
function SpatialConvolutionUpsample:_unshuffleGradOutput(input, gradOutput)
   local batched, n, h, w = spatialShape(input)
   local f, C = self.factor, self.nOutputPlaneU
   -- (n, C, h, f, w, f) -> (n, C, f, f, h, w)
   local unshuffled = gradOutput:contiguous():view(n, C, h, f, w, f):permute(1, 2, 4, 6, 3, 5)
   self._gradConv = self._gradConv or gradOutput.new()
   self._gradConv:resize(n, C, f, f, h, w):copy(unshuffled)
   if batched then
      return self._gradConv:view(n, C * f * f, h, w)
   end
   return self._gradConv:view(C * f * f, h, w)
end

--- Backward pass w.r.t. the input.
function SpatialConvolutionUpsample:updateGradInput(input, gradOutput)
   self.gradInput = parent.updateGradInput(self, input, self:_unshuffleGradOutput(input, gradOutput))
   return self.gradInput
end

--- Accumulates weight/bias gradients, with `scale` forwarded to the parent.
function SpatialConvolutionUpsample:accGradParameters(input, gradOutput, scale)
   parent.accGradParameters(self, input, self:_unshuffleGradOutput(input, gradOutput), scale)
end

--- Direct parameter update, with `scale` forwarded to the parent.
function SpatialConvolutionUpsample:accUpdateGradParameters(input, gradOutput, scale)
   parent.accUpdateGradParameters(self, input, self:_unshuffleGradOutput(input, gradOutput), scale)
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment