@hazdzz
Created May 2, 2021 05:31
Dilated Causal Convolution
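
The classes below implement dilated causal convolutions in 1D and 2D. When enable_padding is set, the input is padded only on the "past" side (left in time for 1D, top and left for 2D) by (kernel_size - 1) * dilation, so each output position depends only on current and earlier inputs and the padded output keeps the input length.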
import torch
import torch.nn as nn
import torch.nn.functional as F


class CausalConv1d(nn.Conv1d):
    """1D dilated causal convolution: the output at time step t
    depends only on inputs at time steps <= t."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 enable_padding=False, dilation=1, groups=1, bias=True):
        # Left context needed for causality: (kernel_size - 1) * dilation.
        if enable_padding:
            self.__padding = (kernel_size - 1) * dilation
        else:
            self.__padding = 0
        super().__init__(in_channels, out_channels, kernel_size=kernel_size,
                         stride=stride, padding=self.__padding,
                         dilation=dilation, groups=groups, bias=bias)

    def forward(self, input):
        result = super().forward(input)
        if self.__padding != 0:
            # nn.Conv1d pads both sides; trim the right end so that no
            # output position can see future inputs.
            return result[:, :, :-self.__padding]
        return result


class CausalConv2d(nn.Conv2d):
    """2D dilated causal convolution: pads only the top and left edges,
    so each output location depends only on positions above and to the
    left of it (the "past" in both dimensions)."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 enable_padding=False, dilation=1, groups=1, bias=True):
        # _pair is torch's internal helper that expands a scalar to a 2-tuple.
        kernel_size = nn.modules.utils._pair(kernel_size)
        stride = nn.modules.utils._pair(stride)
        dilation = nn.modules.utils._pair(dilation)
        if enable_padding:
            self.__padding = [(kernel_size[i] - 1) * dilation[i]
                              for i in range(len(kernel_size))]
        else:
            self.__padding = 0
        self.left_padding = nn.modules.utils._pair(self.__padding)
        super().__init__(in_channels, out_channels, kernel_size, stride=stride,
                         padding=0, dilation=dilation, groups=groups, bias=bias)

    def forward(self, input):
        if self.__padding != 0:
            # F.pad takes (left, right, top, bottom); pad only left and top.
            input = F.pad(input, (self.left_padding[1], 0, self.left_padding[0], 0))
        return super().forward(input)
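
A quick shape check, not part of the original gist (the batch, channel, and input sizes below are arbitrary), confirms that with enable_padding=True both layers preserve the time and spatial dimensions while remaining causal:

# CausalConv1d: left padding of (3 - 1) * 2 = 4 keeps the length at 100.
x1 = torch.randn(8, 16, 100)  # (batch, channels, time)
conv1 = CausalConv1d(16, 32, kernel_size=3, enable_padding=True, dilation=2)
print(conv1(x1).shape)  # torch.Size([8, 32, 100])

# CausalConv2d: top/left padding of 4 in each dimension keeps 28 x 28.
x2 = torch.randn(8, 16, 28, 28)  # (batch, channels, height, width)
conv2 = CausalConv2d(16, 32, kernel_size=3, enable_padding=True, dilation=2)
print(conv2(x2).shape)  # torch.Size([8, 32, 28, 28])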