Skip to content

Instantly share code, notes, and snippets.

@guillefix
Last active May 10, 2020 14:00
Show Gist options
  • Save guillefix/23bff068bdc457649b81027942873ce5 to your computer and use it in GitHub Desktop.
Save guillefix/23bff068bdc457649b81027942873ce5 to your computer and use it in GitHub Desktop.
Temporary workaround to get Conv2dLocal working in PyTorch
# coding: utf-8
# In[1]:
import math
import torch
from torch.nn.parameter import Parameter
import torch.nn.functional as F
import torch.nn as nn
Module = nn.Module
import collections
from itertools import repeat
# In[2]:
def _ntuple(n):
    """Return a converter that expands a scalar argument into an ``n``-tuple.

    The returned ``parse(x)`` passes any iterable ``x`` (tuple, list,
    ``torch.Size``, ...) through unchanged and otherwise returns ``(x,) * n``,
    e.g. ``_ntuple(2)(3) == (3, 3)``.
    """
    # ``collections.Iterable`` was removed in Python 3.10; the ABC lives in
    # ``collections.abc`` (available since Python 3.3).
    from collections.abc import Iterable

    def parse(x):
        if isinstance(x, Iterable):
            return x
        return tuple(repeat(x, n))

    return parse


# Converters for 1- to 4-element size arguments (int or sequence).
_single = _ntuple(1)
_pair = _ntuple(2)
_triple = _ntuple(3)
_quadruple = _ntuple(4)
# In[3]:
class _ConvNd(Module):
    """Parameter container shared by N-dimensional (transposed) convolutions.

    Stores the convolution hyper-parameters as plain attributes and creates a
    ``weight`` parameter (plus an optional ``bias``), both initialised
    uniformly in ``[-1/sqrt(fan_in), 1/sqrt(fan_in)]`` where
    ``fan_in = in_channels * prod(kernel_size)``. This class defines no
    ``forward`` of its own.

    ``kernel_size`` must be a sequence (it is star-unpacked into the weight
    shape), and ``padding``, ``dilation`` and ``output_padding`` must support
    ``len()`` (used by ``__repr__``).

    Raises:
        ValueError: if ``in_channels`` or ``out_channels`` is not divisible
            by ``groups``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride,
                 padding, dilation, transposed, output_padding, groups, bias):
        super(_ConvNd, self).__init__()
        divisibility_checks = (
            (in_channels, 'in_channels must be divisible by groups'),
            (out_channels, 'out_channels must be divisible by groups'),
        )
        for channel_count, message in divisibility_checks:
            if channel_count % groups != 0:
                raise ValueError(message)

        self.in_channels, self.out_channels = in_channels, out_channels
        self.kernel_size, self.stride = kernel_size, stride
        self.padding, self.dilation = padding, dilation
        self.transposed, self.output_padding = transposed, output_padding
        self.groups = groups

        # Transposed convolutions swap the roles of the two channel axes.
        if transposed:
            channel_dims = (in_channels, out_channels // groups)
        else:
            channel_dims = (out_channels, in_channels // groups)
        self.weight = Parameter(torch.Tensor(*channel_dims, *kernel_size))

        if not bias:
            # Register the slot explicitly so ``self.bias`` resolves to None.
            self.register_parameter('bias', None)
        else:
            self.bias = Parameter(torch.Tensor(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw ``weight`` (and ``bias``) from U(-1/sqrt(fan_in), 1/sqrt(fan_in))."""
        fan_in = self.in_channels
        for extent in self.kernel_size:
            fan_in *= extent
        bound = 1. / math.sqrt(fan_in)
        with torch.no_grad():
            self.weight.uniform_(-bound, bound)
            if self.bias is not None:
                self.bias.uniform_(-bound, bound)

    def __repr__(self):
        # Each entry is (include?, template fragment); only non-default
        # settings are shown.
        optional_parts = (
            (self.padding != (0,) * len(self.padding), ', padding={padding}'),
            (self.dilation != (1,) * len(self.dilation), ', dilation={dilation}'),
            (self.output_padding != (0,) * len(self.output_padding),
             ', output_padding={output_padding}'),
            (self.groups != 1, ', groups={groups}'),
            (self.bias is None, ', bias=False'),
        )
        template = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
                    ', stride={stride}')
        template += ''.join(fragment for wanted, fragment in optional_parts if wanted)
        template += ')'
        return template.format(name=type(self).__name__, **self.__dict__)
# In[4]:
class Conv2dLocal(Module):
    """2-D locally connected layer: a convolution whose filters are not shared.

    Every output position ``(i, j)`` owns its own filter bank, so the layer is
    sized for one fixed input size ``(in_height, in_width)``.

    Args:
        in_height, in_width: spatial size of the inputs the layer is built
            for; used to derive ``out_height`` / ``out_width``.
        in_channels, out_channels: number of input / output channels.
        kernel_size, stride, padding, dilation: int or pair, normalised with
            ``_pair``.
        bias: if true, learn a separate bias per output channel and position.

    Parameters created:
        weight: ``(out_height, out_width, out_channels, in_channels, kH, kW)``.
        bias: ``(out_channels, out_height, out_width)``, or ``None``.
    """

    def __init__(self, in_height, in_width, in_channels, out_channels,
                 kernel_size, stride=1, padding=0, bias=True, dilation=1):
        super(Conv2dLocal, self).__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.kernel_size, self.stride, self.padding, self.dilation = (
            _pair(kernel_size), _pair(stride), _pair(padding), _pair(dilation))
        self.in_height, self.in_width = in_height, in_width

        # Output size per axis:
        # floor((size + 2*padding - dilation*(kernel-1) - 1) / stride + 1)
        spatial_in = (in_height, in_width)
        spatial_out = []
        for axis in (0, 1):
            numerator = (spatial_in[axis] + 2 * self.padding[axis]
                         - self.dilation[axis] * (self.kernel_size[axis] - 1) - 1)
            spatial_out.append(int(math.floor(numerator / self.stride[axis] + 1)))
        self.out_height, self.out_width = spatial_out

        grid = (self.out_height, self.out_width)
        self.weight = Parameter(torch.Tensor(
            *grid, out_channels, in_channels, *self.kernel_size))
        if not bias:
            # Register the slot explicitly so ``self.bias`` resolves to None.
            self.register_parameter('bias', None)
        else:
            self.bias = Parameter(torch.Tensor(out_channels, *grid))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw all parameters from U(-1/sqrt(fan_in), 1/sqrt(fan_in))."""
        fan_in = self.in_channels
        for extent in self.kernel_size:
            fan_in *= extent
        bound = 1. / math.sqrt(fan_in)
        targets = [self.weight] if self.bias is None else [self.weight, self.bias]
        with torch.no_grad():
            for tensor in targets:
                tensor.uniform_(-bound, bound)

    def __repr__(self):
        # Each entry is (include?, template fragment); defaults are omitted.
        extras = (
            (self.padding != (0,) * len(self.padding), ', padding={padding}'),
            (self.dilation != (1,) * len(self.dilation), ', dilation={dilation}'),
            (self.bias is None, ', bias=False'),
        )
        template = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
                    ', stride={stride}')
        template += ''.join(fragment for enabled, fragment in extras if enabled) + ')'
        return template.format(name=type(self).__name__, **self.__dict__)

    def forward(self, input):
        """Apply the locally connected convolution via ``conv2d_local``."""
        return conv2d_local(input, self.weight, self.bias,
                            padding=self.padding, stride=self.stride,
                            dilation=self.dilation)
# In[5]:
unfold = F.unfold  # requires a PyTorch version that provides F.unfold
# In[6]:
def conv2d_local(input, weight, bias=None, padding=0, stride=1, dilation=1):
    """Apply a 2-D locally connected (unshared-weight) convolution.

    Args:
        input: 4-D tensor ``(N, inC, H, W)``.
        weight: 6-D tensor ``(outH, outW, outC, inC, kH, kW)``, holding one
            filter bank per output position.
        bias: optional tensor broadcastable to ``(N, outC, outH, outW)``,
            typically ``(outC, outH, outW)``.
        padding, stride, dilation: int or pair, forwarded to ``F.unfold``.

    Returns:
        Tensor of shape ``(N, outC, outH, outW)``.

    Raises:
        NotImplementedError: if ``input`` is not 4-D or ``weight`` is not 6-D.
        RuntimeError: if the input's channel count or spatial size does not
            match the geometry ``weight`` was built for. Errors raised by
            ``F.unfold`` itself propagate unchanged.
    """

    def _as_pair(value):
        # Same int-or-pair convention as F.unfold's arguments.
        if isinstance(value, (tuple, list)):
            return tuple(value)
        return (value, value)

    if input.dim() != 4:
        raise NotImplementedError("Input Error: Only 4D input Tensors supported (got {}D)".format(input.dim()))
    if weight.dim() != 6:
        # outH x outW x outC x inC x kH x kW
        raise NotImplementedError("Input Error: Only 6D weight Tensors supported (got {}D)".format(weight.dim()))
    outH, outW, outC, inC, kH, kW = weight.size()

    # N x [inC * kH * kW] x [outH * outW]; unfold validates stride/padding itself.
    cols = unfold(input, (kH, kW), dilation=dilation, padding=padding, stride=stride)

    # Check that the input matches the geometry the per-position filters were
    # built for. Without this, an input whose patch grid has the same *number*
    # of positions but a different layout (e.g. 6x4 instead of 4x6) would be
    # silently paired with the wrong filters.
    batch, channels, height, width = input.size()
    if channels != inC:
        raise RuntimeError(
            "Input Error: weight expects {} input channels (got {})".format(inC, channels))
    strides, pads, dils = _as_pair(stride), _as_pair(padding), _as_pair(dilation)
    grid_h = (height + 2 * pads[0] - dils[0] * (kH - 1) - 1) // strides[0] + 1
    grid_w = (width + 2 * pads[1] - dils[1] * (kW - 1) - 1) // strides[1] + 1
    if (grid_h, grid_w) != (outH, outW):
        raise RuntimeError(
            "Input Error: a {}x{} input gives a {}x{} output grid, but weight was "
            "built for {}x{}".format(height, width, grid_h, grid_w, outH, outW))

    # -> N x [outH * outW] x 1 x [inC * kH * kW]: one row vector per position.
    cols = cols.transpose(1, 2).unsqueeze(2)
    # -> [outH * outW] x [inC * kH * kW] x outC; reshape also accepts
    # non-contiguous weights.
    kernels = weight.reshape(outH * outW, outC, inC * kH * kW).transpose(1, 2)
    # Batched matmul broadcasts over N: N x [outH * outW] x 1 x outC.
    out = torch.matmul(cols, kernels)
    out = out.reshape(batch, outH, outW, outC).permute(0, 3, 1, 2)
    if bias is not None:
        out = out + bias
    return out
# In[8]:
# lc = Conv2dLocal(3, 3, 64, 2,3)
# In[9]:
# lc(torch.autograd.Variable(torch.randn((1,64,3,3))))
# In[43]:
# x=torch.autograd.Variable(torch.randn((64,6,6)))
# In[47]:
# lc._backend.SpatialConvolutionLocal??
# In[58]:
# from torch.nn import Conv2dLocal
# In[59]:
# Conv2dLocal??
@kunrenzhilu
Copy link

kunrenzhilu commented Apr 4, 2018

Hi, why don't I have the F.unfold function? And why does this line

out = torch.matmul(cols, weight.view(outH * outW, outC, inC * kH * kW).permute(0, 2, 1))

work? I guess `cols` has 4 dimensions after the permutation, with shape (N x [outH * outW] x 1 x [inC * kH * kW]), and the weight has 3 dimensions after the permutation, with shape ([outH * outW] x [inC * kH * kW] x outC). Unless I am misunderstanding, shouldn't the two arguments of torch.matmul() have the same number of dimensions?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment