Skip to content

Instantly share code, notes, and snippets.

@MarvinTeichmann
Forked from halochou/conv2dlocal.py
Created November 3, 2017 18:22
Show Gist options
  • Save MarvinTeichmann/fc895fa1a5f1d388bc4a9a032c005443 to your computer and use it in GitHub Desktop.
Save MarvinTeichmann/fc895fa1a5f1d388bc4a9a032c005443 to your computer and use it in GitHub Desktop.
A PyTorch wrapper for SpatialConvolutionLocal
import torch
from torch.autograd import Function
from torch._thnn import type2backend
from torch.nn.modules.utils import _pair
# THNN entry points for the locally connected 2-D convolution. Each name is
# resolved with getattr() on the backend that matches the input tensor type.
update_output_name, grad_input_name, grad_params_name = (
    'SpatialConvolutionLocal_updateOutput',
    'SpatialConvolutionLocal_updateGradInput',
    'SpatialConvolutionLocal_accGradParameters',
)
class Conv2dLocal(Function):
    """Locally connected (unshared-weight) 2-D convolution via THNN.

    Wraps the THNN ``SpatialConvolutionLocal`` kernels in an instance-style
    autograd ``Function``, used as::

        Conv2dLocal(stride, padding, dilation, iW, iH, oW, oH)(input, weight)

    ``weight`` is unpacked as ``(oH, oW, oC, ?, kH, kW)``, so every output
    location has its own filter bank. The fourth dimension is not read here
    (presumably the input channel count). A zero bias is passed to THNN, and
    no bias gradient is returned.

    NOTE(review): this depends on the private ``torch._thnn`` backend and on
    non-static ``forward``/``backward`` methods. Both appear to be legacy
    PyTorch features, so confirm which torch version this targets.
    """

    def __init__(self, stride, padding, dilation, iW, iH, oW, oH):
        """Store the convolution hyper-parameters.

        Args:
            stride: int or pair, unpacked as ``(dH, dW)``.
            padding: int or pair, unpacked as ``(padH, padW)``.
            dilation: int or pair; must be 1 in both dimensions. The THNN
                kernel takes no dilation argument, so any other value would
                otherwise be silently ignored.
            iW, iH, oW, oH: spatial sizes. They are stored for reference
                only; ``forward``/``backward`` re-derive them from the tensor
                shapes.

        Raises:
            ValueError: if ``dilation`` is not 1 in both dimensions.
        """
        super(Conv2dLocal, self).__init__()
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)
        # Iterate element-wise rather than comparing to a tuple, because some
        # _pair implementations return an iterable argument (e.g. a list) as-is.
        if any(d != 1 for d in self.dilation):
            raise ValueError(
                'Conv2dLocal does not support dilation, got {}'.format(
                    self.dilation))
        self.iW = iW
        self.iH = iH
        self.oW = oW
        self.oH = oH

    def forward(self, input, weight):
        """Apply the locally connected convolution.

        Args:
            input: 4-D tensor, unpacked as ``(N, iC, iH, iW)``.
            weight: 6-D tensor, unpacked as ``(oH, oW, oC, ?, kH, kW)``.

        Returns:
            The output tensor, allocated as ``(N, oC, oH, oW)`` and filled by
            the THNN ``updateOutput`` call.
        """
        self.save_for_backward(input, weight)
        N, _, iH, iW = input.size()
        oH, oW, oC, _, kH, kW = weight.size()
        dH, dW = self.stride
        padH, padW = self.padding

        # The THNN call takes a bias argument; passing zeros makes the op
        # bias-free.
        bias = input.new(oC, oH, oW).zero_()
        # Scratch buffers handed to THNN here and reused again in backward().
        self._buffs = [input.new(), input.new()]
        finput, fgrad_input = self._buffs

        output = input.new(N, oC, oH, oW)
        backend = type2backend[type(input)]
        getattr(backend, update_output_name)(
            backend.library_state,
            input, output,
            weight, bias,
            finput, fgrad_input,
            kW, kH, dW, dH, padW, padH,
            iW, iH, oW, oH)
        return output

    def backward(self, grad_output):
        """Compute gradients with respect to ``input`` and ``weight``.

        Returns:
            ``(grad_input, grad_weight)``. Each entry is ``None`` when the
            corresponding forward input does not need a gradient.
        """
        input, weight = self.saved_tensors
        _, _, iH, iW = input.size()
        oH, oW, oC, _, kH, kW = weight.size()
        dH, dW = self.stride
        padH, padW = self.padding
        finput, fgrad_input = self._buffs
        backend = type2backend[type(input)]

        grad_input = None
        if self.needs_input_grad[0]:
            grad_input = input.new().resize_as_(input)
            getattr(backend, grad_input_name)(
                backend.library_state,
                input, grad_output, grad_input,
                weight, finput, fgrad_input,
                kW, kH, dW, dH, padW, padH,
                iW, iH, oW, oH, 1.0)

        grad_weight = None
        if self.needs_input_grad[1]:
            grad_weight = weight.new().resize_as_(weight).zero_()
            # accGradParameters is also handed a bias-gradient buffer, shaped
            # like the internal zero bias. Its contents are discarded because
            # the bias is not an input of this Function.
            grad_bias = input.new(oC, oH, oW).zero_()
            getattr(backend, grad_params_name)(
                backend.library_state,
                input, grad_output, grad_weight, grad_bias,
                finput, fgrad_input,
                kW, kH, dW, dH, padW, padH,
                iW, iH, oW, oH, 1.0)

        return grad_input, grad_weight
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment