Skip to content

Instantly share code, notes, and snippets.

@arunmallya
Created February 20, 2018 20:43
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save arunmallya/34524996c5c5246e0106cd05743af5d1 to your computer and use it in GitHub Desktop.
Convolution with masking support.
class ElementWiseConv2d(nn.Module):
    """2D convolution whose effective weight is ``thresholded(mask) * weight``.

    The real-valued convolution weight (and optional bias) are frozen
    Variables (``requires_grad=False``) rather than Parameters; only the
    real-valued mask ``mask_real`` is trainable.  At forward time the mask
    is pushed through a thresholding function (binarizer or ternarizer) and
    multiplied element-wise into the frozen weight before ``F.conv2d``.

    NOTE(review): the bias, when present, is applied unmasked — whether a
    mask is also needed for biases is an open question (original comment).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 mask_init='1s', mask_scale=1e-2,
                 threshold_fn='binarizer', threshold=None):
        """Mirror nn.Conv2d's constructor, plus mask configuration.

        Args:
            mask_init: '1s' (fill with mask_scale) or 'uniform'
                (uniform in [-mask_scale, mask_scale]).
            mask_scale: magnitude used by the mask initializer.
            threshold_fn: 'binarizer' or 'ternarizer'.
            threshold: threshold passed to the thresholder; defaults to
                DEFAULT_THRESHOLD when None.

        Raises:
            ValueError: on invalid groups, mask_init, or threshold_fn.
        """
        super(ElementWiseConv2d, self).__init__()
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        self.mask_scale = mask_scale
        self.mask_init = mask_init

        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.transposed = False
        self.output_padding = _pair(0)
        self.groups = groups

        # weight and bias are no longer Parameters: they stay frozen and
        # only the mask below is trained.
        self.weight = Variable(torch.Tensor(
            out_channels, in_channels // groups, *kernel_size),
            requires_grad=False)
        if bias:
            self.bias = Variable(torch.Tensor(
                out_channels), requires_grad=False)
        else:
            self.register_parameter('bias', None)

        # Initialize real-valued mask with the same shape as the weights.
        self.mask_real = self.weight.data.new(self.weight.size())
        if mask_init == '1s':
            self.mask_real.fill_(mask_scale)
        elif mask_init == 'uniform':
            self.mask_real.uniform_(-1 * mask_scale, mask_scale)
        else:
            # BUGFIX: an unknown mask_init previously left the mask as
            # uninitialized memory, silently.
            raise ValueError('Unknown mask_init: {}'.format(mask_init))
        # mask_real is the only trainable tensor in this module.
        self.mask_real = Parameter(self.mask_real)

        # Initialize the thresholder.
        if threshold is None:
            threshold = DEFAULT_THRESHOLD
        if threshold_fn == 'binarizer':
            print('Calling binarizer with threshold:', threshold)
            self.threshold_fn = Binarizer(threshold=threshold)
        elif threshold_fn == 'ternarizer':
            print('Calling ternarizer with threshold:', threshold)
            self.threshold_fn = Ternarizer(threshold=threshold)
        else:
            # BUGFIX: an unknown threshold_fn previously left the attribute
            # as a raw string, crashing obscurely in forward().
            raise ValueError('Unknown threshold_fn: {}'.format(threshold_fn))

    def forward(self, input):
        """Run conv2d with the mask-modulated (thresholded) weight."""
        # Threshold the real-valued mask, then modulate the frozen weight
        # element-wise; gradients flow only into mask_real.
        mask_thresholded = self.threshold_fn(self.mask_real)
        weight_thresholded = mask_thresholded * self.weight
        return F.conv2d(input, weight_thresholded, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def __repr__(self):
        """Conv2d-style repr, omitting options left at their defaults."""
        s = ('{name} ({in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0,) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)

    def _apply(self, fn):
        """Apply ``fn`` to every tensor, including the non-Parameter
        weight/bias that the stock nn.Module._apply would miss.

        Returns:
            self, so that .cuda()/.cpu()/.float() chains keep working.
        """
        for module in self.children():
            module._apply(fn)
        for param in self._parameters.values():
            if param is not None:
                # Variables stored in modules are graph leaves, and we don't
                # want to create copy nodes, so we have to unpack the data.
                param.data = fn(param.data)
                if param._grad is not None:
                    param._grad.data = fn(param._grad.data)
        for key, buf in self._buffers.items():
            if buf is not None:
                self._buffers[key] = fn(buf)
        # weight/bias are plain Variables (not in _parameters), so they
        # must be moved explicitly.
        self.weight.data = fn(self.weight.data)
        if self.bias is not None and self.bias.data is not None:
            self.bias.data = fn(self.bias.data)
        # BUGFIX: nn.Module._apply returns self; the original returned
        # None, breaking `model = model.cuda()`-style call sites.
        return self
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment