@daquexian
Last active February 6, 2022 16:20
A correct and dabnn-compatible binary convolution PyTorch implementation
# A correct and dabnn-compatible PyTorch implementation of binary convolutions.
# It consists of an implementation of the binary convolution itself and of the
# way to make the implementation both ONNX- and dabnn-compatible:
# 1. The input of a binary convolution should only be +1/-1, so we pad with -1
#    instead of 0 by an explicit pad operation.
# 2. Since PyTorch doesn't support exporting the Sign ONNX operator (until
#    https://github.com/pytorch/pytorch/pull/20470 gets merged), we perform the
#    sign operation on input and weight by directly accessing their `data`.
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Function


class SignSTE(Function):
    """Sign with a straight-through estimator (STE) for the gradient."""

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.sign()

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # Pass gradients through only where the input lies in [-1, 1];
        # outside that range the estimator is saturated and the gradient is 0.
        mask = input.ge(-1) & input.le(1)
        grad_input = torch.where(
            mask, grad_output, torch.zeros_like(grad_output))
        return grad_input
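
# As a quick sanity check (illustrative sketch): the STE zeroes gradients
# where |input| > 1 and passes them through unchanged otherwise:
#   x = torch.tensor([-2.0, -0.5, 0.5, 2.0], requires_grad=True)
#   SignSTE.apply(x).sum().backward()
#   x.grad  # -> tensor([0., 1., 1., 0.])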


class SignWeight(Function):
    """Sign for weights; the backward is the identity function."""

    @staticmethod
    def forward(ctx, input):
        return input.sign()

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = grad_output.new_empty(grad_output.size())
        grad_input.copy_(grad_output)
        return grad_input
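
# For reference (illustrative sketch): the identity backward means the latent
# real-valued weights receive the full gradient during training:
#   w = torch.randn(4, requires_grad=True)
#   SignWeight.apply(w).sum().backward()
#   w.grad  # -> tensor([1., 1., 1., 1.])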


class BinaryConv2d(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(BinaryConv2d, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding,
            dilation, groups, bias)

    def forward(self, input):
        if self.training:
            input = SignSTE.apply(input)
            self.weight_bin_tensor = SignWeight.apply(self.weight)
        else:
            # We clone the input here because editing the data of the
            # `input` tensor in place causes unexpected behaviors.
            input = input.clone()
            input.data = input.sign()
            # Even though there is a UserWarning here, we have to use
            # `new_tensor` rather than the "recommended" way.
            self.weight_bin_tensor = self.weight.new_tensor(self.weight.sign())

        # 1. The input of a binary convolution should only be +1 or -1, so
        #    instead of letting conv2d pad with 0 automatically, we pad
        #    with -1 ourselves.
        # 2. `padding` of nn.Conv2d is always a tuple of (padH, padW), while
        #    the pad parameter of F.pad is (padLeft, padRight, padTop,
        #    padBottom), so left/right take padW (= self.padding[1]) and
        #    top/bottom take padH (= self.padding[0]).
        input = F.pad(input, (self.padding[1], self.padding[1],
                              self.padding[0], self.padding[0]),
                      mode='constant', value=-1)
        out = F.conv2d(input, self.weight_bin_tensor, self.bias, self.stride,
                       0, self.dilation, self.groups)
        return out
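

# A minimal usage sketch, assuming a 3x3 binary convolution applied to a
# random 1x16x32x32 input (names and shapes here are illustrative):
if __name__ == "__main__":
    conv = BinaryConv2d(16, 32, kernel_size=3, padding=1, bias=False)
    x = torch.randn(1, 16, 32, 32)

    conv.train()
    print(conv(x).shape)  # sign via STE; prints torch.Size([1, 32, 32, 32])

    conv.eval()
    print(conv(x).shape)  # sign via `.data`; prints torch.Size([1, 32, 32, 32])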
@mmbejani commented Feb 6, 2022:

Did you benchmark it?
