Skip to content

Instantly share code, notes, and snippets.

@grey-area
Created February 24, 2020 17:20
Show Gist options
  • Save grey-area/0e9ad94515087facadd704e64f6392d8 to your computer and use it in GitHub Desktop.
Save grey-area/0e9ad94515087facadd704e64f6392d8 to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
class MirroredConv1d(nn.Module):
    """1-D convolution whose output channels come in mirrored pairs.

    A single ``nn.Conv1d`` with ``out_channels // 2`` filters is applied twice:
    once to the input as-is, and once to the input reversed along the length
    axis (with that result reversed back). The two results are concatenated
    along the channel axis. The first half of the output channels is the plain
    convolution. The second half is the response of the same filters run over
    a time-reversed input; for stride 1 and symmetric padding this equals
    convolving with each kernel reversed. Weights are shared between the
    halves, so the layer has the parameters of a convolution with
    ``out_channels // 2`` outputs.

    Args:
        in_channels: Number of input channels, forwarded to ``nn.Conv1d``.
        out_channels: Total number of output channels; must be even.
        kernel_size: Kernel size, forwarded to ``nn.Conv1d``.
        *args, **kwargs: Any further ``nn.Conv1d`` arguments (``stride``,
            ``padding``, ``dilation``, ``groups``, ``bias``, ...).

    Raises:
        ValueError: If ``out_channels`` is odd.

    NOTE(review): with ``stride > 1`` the mirrored half is only aligned with
    the direct half when ``L + 2*padding - dilation*(kernel_size - 1) - 1`` is
    a multiple of ``stride``. Otherwise the mirrored windows are shifted by the
    remainder. Asymmetric padding (e.g. ``padding='same'`` with an even
    kernel) is likewise applied to the opposite side in the mirrored half.
    Stride 1 with symmetric padding is the fully consistent configuration.
    """

    def __init__(self, in_channels, out_channels, kernel_size, *args, **kwargs):
        super().__init__()
        if out_channels % 2 == 1:
            raise ValueError('Number of output channels must be even')
        # Both halves of the output come from the same filter bank, so the
        # underlying convolution only needs half of the requested channels.
        # The attribute name ``conv`` is kept so state_dict keys are unchanged.
        self.conv = nn.Conv1d(
            in_channels, out_channels // 2, kernel_size, *args, **kwargs
        )

    def forward(self, x):
        """Apply the shared filters directly and mirrored; stack the results.

        Args:
            x: Input of shape ``(N, in_channels, L)``, or ``(in_channels, L)``
                where the installed ``nn.Conv1d`` accepts unbatched input.

        Returns:
            Tensor of shape ``(N, out_channels, L_out)`` (or
            ``(out_channels, L_out)`` for unbatched input). Its first
            ``out_channels // 2`` channels are ``self.conv(x)``; the rest are
            the mirrored response.
        """
        # Axes are indexed from the end so batched (N, C, L) and unbatched
        # (C, L) inputs both work: length is always last, channels just before.
        direct = self.conv(x)
        mirrored = torch.flip(
            self.conv(torch.flip(x, dims=(-1,))),
            dims=(-1,),
        )
        return torch.cat((direct, mirrored), dim=-2)
if __name__ == '__main__':
    # Smoke test: a batch of 7 all-zero sequences, 3 channels, length 128.
    batch = torch.zeros(7, 3, 128)
    mirrored_layer = MirroredConv1d(in_channels=3, out_channels=10, kernel_size=3)
    result = mirrored_layer(batch)
    # No padding and stride 1, so the length shrinks by kernel_size - 1:
    # expected torch.Size([7, 10, 126]).
    print(result.shape)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment