import torch
import torch.nn.functional as F
from einops import reduce, rearrange


class DepthwiseRematConvFn(torch.autograd.Function):
    """Factored convolution that rematerializes its weight in backward.

    The dense (out, in, kh, kw) kernel is the product of a pointwise mixing
    matrix k1 of shape (out, in) and per-input-channel spatial kernels k2 of
    shape (in, kh, kw). Only the factors are saved for backward; the dense
    weight is rebuilt on the fly, trading a little compute for memory.
    """

    @staticmethod
    def forward(
        ctx,
        input,
        k1,
        k2,
        bias=None,
        padding=0,
    ):
        ctx.padding = padding
        # Save the factors rather than the materialized weight.
        ctx.save_for_backward(input, k1, k2)
        with torch.no_grad():
            # weight[o, i, h, w] = k1[o, i] * k2[i, h, w]
            weight = torch.einsum("oi,ihw->oihw", k1, k2)
            output = F.conv2d(
                input,
                weight,
                bias,
                padding=padding,
            )
        return output
    @staticmethod
    def backward(ctx, grad_output):
        input, k1, k2 = ctx.saved_tensors
        padding = ctx.padding
        needs_weight_grad = ctx.needs_input_grad[1] or ctx.needs_input_grad[2]
        grad_input = grad_k1 = grad_k2 = grad_bias = None
        # Rematerialize the dense weight from the saved factors.
        weight = torch.einsum("oi,ihw->oihw", k1, k2)
        if ctx.needs_input_grad[0]:
            # For a stride-1 conv, the input gradient is the transposed
            # convolution of grad_output with the same weight and padding.
            grad_input = F.conv_transpose2d(grad_output, weight, padding=padding)
        if needs_weight_grad:
            # Correlating input with grad_output over the batch dimension
            # yields the dense weight gradient. This conv produces an
            # (in, out, kh, kw) layout, so swap it back to (out, in, kh, kw).
            grad_weight = rearrange(
                F.conv2d(
                    rearrange(input, "b c h w -> c b h w").contiguous(),
                    rearrange(grad_output, "b c h w -> c b h w"),
                    padding=padding,
                ),
                "i o h w -> o i h w",
            )
            if ctx.needs_input_grad[1]:
                # Contract the spatial-kernel factor out of grad_weight.
                grad_k1 = torch.einsum("oihw,ihw->oi", grad_weight, k2)
            if ctx.needs_input_grad[2]:
                # Contract the mixing-matrix factor out of grad_weight.
                grad_k2 = torch.einsum("oihw,oi->ihw", grad_weight, k1)
        if ctx.needs_input_grad[3]:
            grad_bias = reduce(grad_output, "b c h w -> c", "sum")
        # One gradient per forward argument; padding is non-differentiable.
        return grad_input, grad_k1, grad_k2, grad_bias, None


depthwise_remat_conv = DepthwiseRematConvFn.apply
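

# --- Usage sketch (not part of the original gist) ---------------------------
# A hedged check of the custom backward against plain autograd on the
# materialized weight; all shapes and values below are illustrative
# assumptions, not anything the gist itself specifies.
if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(2, 4, 16, 16, dtype=torch.double, requires_grad=True)
    k1 = torch.randn(8, 4, dtype=torch.double, requires_grad=True)  # (out, in)
    k2 = torch.randn(4, 3, 3, dtype=torch.double, requires_grad=True)  # (in, kh, kw)
    bias = torch.randn(8, dtype=torch.double, requires_grad=True)

    # Forward should match a plain conv over the materialized weight.
    out = depthwise_remat_conv(x, k1, k2, bias, 1)
    ref = F.conv2d(x, torch.einsum("oi,ihw->oihw", k1, k2), bias, padding=1)
    assert torch.allclose(out, ref)

    # gradcheck exercises the hand-written backward in double precision.
    torch.autograd.gradcheck(
        lambda x, k1, k2, bias: depthwise_remat_conv(x, k1, k2, bias, 1),
        (x, k1, k2, bias),
    )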