Skip to content

Instantly share code, notes, and snippets.

@jinglescode
Created November 2, 2020 09:34
Show Gist options
  • Save jinglescode/b349b2465f7bc874caad22cc486f187c to your computer and use it in GitHub Desktop.
Save jinglescode/b349b2465f7bc874caad22cc486f187c to your computer and use it in GitHub Desktop.
class TestConv1d(nn.Module):
    """Tiny grouped 1-D convolution used to inspect how ``groups`` routes channels.

    The wrapped ``nn.Conv1d`` has 2 input channels, 4 output channels,
    ``kernel_size=1``, ``groups=2`` and no bias. Its weight therefore has
    shape ``(out_channels, in_channels // groups, kernel_size) == (4, 1, 1)``.
    Output channels 0-1 see only input channel 0, and output channels 2-3
    see only input channel 1.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv1d(in_channels=2, out_channels=4, kernel_size=1, groups=2, bias=False)
        self.init_weights()

    def forward(self, x):
        """Apply the grouped convolution to ``x``.

        ``x`` is expected to have 2 channels, i.e. shape ``(batch, 2, length)``
        (or unbatched ``(2, length)``), as required by ``nn.Conv1d``.
        """
        return self.conv(x)

    def init_weights(self):
        """Print the weight shape, then set the weights to fixed scales.

        Output channels 0, 1, 2 and 3 get the scales 2, 4, 6 and 8
        (one per output channel).
        """
        import torch  # local import: keeps this block self-contained

        print(self.conv.weight.shape)
        # ``weight`` is a leaf Parameter with requires_grad=True. Writing into it
        # in place while autograd is recording raises a RuntimeError, so the
        # initialisation must happen under no_grad. The parameter still
        # requires grad afterwards.
        with torch.no_grad():
            for out_ch, scale in enumerate((2.0, 4.0, 6.0, 8.0)):
                self.conv.weight[out_ch, 0, 0] = scale
# Demo input with batch=1, 2 channels and length 6. Channel 1 is channel 0
# scaled by 10, so the per-group outputs are easy to tell apart.
in_x = torch.tensor(
    [[[1, 2, 3, 4, 5, 6],
      [10, 20, 30, 40, 50, 60]]],
    dtype=torch.float32,
)
print("in_x.shape", in_x.size())
print(in_x)

# Build the grouped-conv module and push the demo input through it once.
net = TestConv1d()
out_y = net(in_x)
print("out_y.shape", out_y.size())
print(out_y)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment