Skip to content

Instantly share code, notes, and snippets.

@pbnsilva
Created April 12, 2020 17:13
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save pbnsilva/77872b9f06bcd347afe75d0e2884d6b9 to your computer and use it in GitHub Desktop.
Save pbnsilva/77872b9f06bcd347afe75d0e2884d6b9 to your computer and use it in GitHub Desktop.
class Net(nn.Module):
    """Stack of ``Conv2dBlock`` stages followed by a pooled linear head.

    Six ``Conv2dBlock`` stages are chained. The filter-count argument handed
    to each stage starts at 32 and grows by 32 per stage. The stacked output
    is reshaped with ``View((-1, 9, 224))``, averaged over dimension 1,
    projected to 4 outputs by ``nn.Linear(224, 4)``, and finally passed
    through a sigmoid in ``forward``.

    NOTE(review): the ``9`` and ``224`` in the reshape are hard-coded.
    They presumably depend on ``input_shape`` and on what each
    ``Conv2dBlock`` emits, so changing ``n_blocks``, the growth rate or the
    input size likely requires updating them. Confirm against the
    ``Conv2dBlock`` / ``View`` definitions, which are not shown here.
    """

    def __init__(self):
        super().__init__()

        # Architecture hyper-parameters.
        depth = 4
        n_blocks = 6
        filters_growth = 32
        strides_start, strides_end = 1, 2
        n_channels = 1
        # Passed to every stage; only the first stage gets first_layer=True.
        # How Conv2dBlock uses it is not visible here — TODO confirm.
        input_shape = (n_channels, 33, 570)

        stages = []
        filters = 32
        for idx in range(n_blocks):
            stages.append(
                Conv2dBlock(
                    depth,
                    filters,
                    filters_growth,
                    strides_start,
                    strides_end,
                    input_shape,
                    first_layer=(idx == 0),
                )
            )
            filters += filters_growth

        # After the loop `filters` == 32 + 6 * 32 == 224, matching the feature
        # size below; presumably the final stage's channel count — verify.
        stages.extend(
            [
                View((-1, 9, 224)),
                # NOTE(review): if LambdaLayer stores this lambda as an
                # attribute, whole-model pickling (torch.save(model)) will
                # fail; saving state_dict() is unaffected.
                LambdaLayer(lambda t: torch.mean(t, dim=1)),
                nn.Linear(224, 4),
            ]
        )
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the stacked layers and apply an element-wise sigmoid."""
        return torch.sigmoid(self.net(x))
@Franco7Scala
Copy link

Hi Pedro Silva, how did you define the `View` module used in `layers.append(View((-1, 9, 224)))`? I would really appreciate a reply. Best wishes!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment