@wassname
Last active January 25, 2023 01:00
pytorch Causal Conv2d
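import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import _pair


class CausalConv2d(nn.Conv2d):
    """Conv2d that zero-pads only the top and left of the input."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=None, dilation=1, groups=1, bias=True):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        dilation = _pair(dilation)
        if padding is None:
            # Default: (kernel - 1) * dilation per axis, which keeps the input size at stride 1.
            padding = [(kernel_size[i] - 1) * dilation[i] for i in range(len(kernel_size))]
        else:
            # nn.Conv2d pads both sides by `padding`; here the whole amount goes on one side.
            # Doubled per axis so that a tuple is treated the same way as an int.
            padding = [2 * p for p in _pair(padding)]
        # Padding is applied manually in forward(), so the base conv gets padding=0.
        super().__init__(in_channels, out_channels, kernel_size,
                         stride=stride, padding=0, dilation=dilation,
                         groups=groups, bias=bias)
        self.left_padding = _pair(padding)

    def forward(self, inputs):
        # F.pad order for the last two dims: (left, right, top, bottom).
        inputs = F.pad(inputs, (self.left_padding[1], 0, self.left_padding[0], 0))
        return super().forward(inputs)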
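
One way to sanity-check a layer like this is to change the last rows and columns of an input and confirm that earlier outputs do not move. The sketch below is not part of the gist: it re-declares a compact copy of the class that supports only the default padding, so it runs on its own. The 6x6 input, the seed, and the names `x2` and `y2` are illustrative assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import _pair


class CausalConv2d(nn.Conv2d):
    """Compact copy for this check only (assumption: default padding, stride 1)."""

    def __init__(self, in_channels, out_channels, kernel_size, dilation=1):
        kernel_size = _pair(kernel_size)
        dilation = _pair(dilation)
        super().__init__(in_channels, out_channels, kernel_size,
                         padding=0, dilation=dilation)
        # (kernel - 1) * dilation zeros on the top (index 0) and left (index 1).
        self.left_padding = tuple((k - 1) * d for k, d in zip(kernel_size, dilation))

    def forward(self, inputs):
        # F.pad order for the last two dims: (left, right, top, bottom).
        inputs = F.pad(inputs, (self.left_padding[1], 0, self.left_padding[0], 0))
        return super().forward(inputs)


torch.manual_seed(0)
conv = CausalConv2d(1, 1, kernel_size=3)
x = torch.randn(1, 1, 6, 6)

# Perturb only rows 4-5 and columns 4-5 of a copy of the input.
x2 = x.clone()
x2[..., 4:, :] += 1.0
x2[..., :, 4:] += 1.0

with torch.no_grad():
    y = conv(x)
    y2 = conv(x2)

# Output (i, j) reads input rows i-2..i and columns j-2..j only, so the
# top-left 4x4 block of the output should be unaffected by the perturbation.
print("same shape as input:", y.shape == x.shape)
print("earlier outputs unchanged:", torch.allclose(y[..., :4, :4], y2[..., :4, :4]))
print("later outputs changed:", not torch.allclose(y, y2))
```

The check treats both spatial axes as causal, which matches padding the top and left only. If just one axis should be causal, the other would need ordinary symmetric padding instead.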
@jbetker
jbetker commented Jan 23, 2023

Note from a random person: this expects padding to be different from what is traditionally passed to nn.Conv2d. I believe line 11 should include:

else:
   padding = padding * 2
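
To illustrate the arithmetic behind this note: nn.Conv2d applies `padding=p` to both sides of each axis, so 2*p in total, whereas the causal layer pads one side only. Doubling the value keeps the output size the same. A minimal sketch in plain Python, using the standard convolution output-size formula. The helper name `conv_output_size` and the sample numbers are illustrative assumptions, not part of the gist.

```python
def conv_output_size(size, kernel, pad_before, pad_after, dilation=1, stride=1):
    """Output length along one axis of a convolution."""
    span = dilation * (kernel - 1) + 1  # input positions covered by one kernel window
    return (size + pad_before + pad_after - span) // stride + 1


size, kernel, p = 10, 3, 1

symmetric = conv_output_size(size, kernel, p, p)         # nn.Conv2d-style padding=p
causal = conv_output_size(size, kernel, 2 * p, 0)        # all 2*p on the leading side
causal_undoubled = conv_output_size(size, kernel, p, 0)  # without the doubling

print(symmetric, causal, causal_undoubled)
```

With these numbers the symmetric and doubled one-sided cases both give 10, while the undoubled one-sided case gives 9.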

@wassname
Author

That's the best kind of note! Thanks, and it's updated
