@pbnsilva
Created April 12, 2020 17:13
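
import torch
import torch.nn as nn

# Conv2dBlock and View are helper modules that are not defined in this snippet;
# LambdaLayer is defined in the comments below.


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        filters_start = 32
        layer_filters = filters_start
        filters_growth = 32
        strides_start = 1
        strides_end = 2
        depth = 4
        n_blocks = 6
        n_channels = 1
        input_shape = (n_channels, 33, 570)  # (channels, height, width) of one sample
        layers = []
        for block in range(n_blocks):
            # Only the first block is flagged as the input layer.
            provide_input = block == 0
            layers.append(Conv2dBlock(depth,
                                      layer_filters,
                                      filters_growth,
                                      strides_start, strides_end,
                                      input_shape,
                                      first_layer=provide_input))
            layer_filters += filters_growth  # filter count grows by 32 per block
        layers.append(View((-1, 9, 224)))
        # Average over dim 1 (the 9 rows), giving one 224-value vector per row of the view.
        layers.append(LambdaLayer(lambda x: torch.mean(x, dim=1)))
        layers.append(nn.Linear(224, 4))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        x = self.net(x)
        x = torch.sigmoid(x)  # each of the 4 outputs mapped independently into (0, 1)
        return x
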
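Conv2dBlock is not shown in the gist, so the full model cannot be run as is. The sketch below exercises only the classifier head on a dummy tensor standing in for the convolutional output. The feature shape (2, 2016), i.e. 9 * 224 = 2016 values per sample, is an assumption chosen so that the reshape to (-1, 9, 224) keeps one row per sample; plain tensor methods stand in for the View and LambdaLayer helpers.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the Conv2dBlock stack output: 2 samples,
# 9 * 224 = 2016 features each (assumed, not taken from the gist).
features = torch.randn(2, 2016)

head = nn.Linear(224, 4)

x = features.view(-1, 9, 224)  # same reshape as View((-1, 9, 224)): (2, 9, 224)
x = torch.mean(x, dim=1)       # same as the LambdaLayer: average the 9 rows -> (2, 224)
x = head(x)                    # (2, 4)
probs = torch.sigmoid(x)       # each of the 4 outputs squashed independently into (0, 1)

print(probs.shape)  # torch.Size([2, 4])
```

Because the outputs go through a sigmoid rather than a softmax, the four values need not sum to 1, which fits a multi-label target (for example, trained with nn.BCELoss). That reading of the task is an inference; the gist does not state it.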
@DrPingLu

Hi Pedro Silva, I would like to ask how you defined LambdaLayer in layers.append(LambdaLayer(lambda x: torch.mean(x, axis=1))). I would appreciate your reply. Best wishes!

@pbnsilva
Author

You can define it like this:

class LambdaLayer(nn.Module):
    def __init__(self, lambd):
        super(LambdaLayer, self).__init__()
        self.lambd = lambd
    def forward(self, x):
        return self.lambd(x)
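A quick check of how LambdaLayer behaves inside nn.Sequential: its forward simply calls the stored function, so the wrapper has no parameters of its own. The tensor sizes below are arbitrary.

```python
import torch
import torch.nn as nn


class LambdaLayer(nn.Module):
    """Wraps an arbitrary function so it can be used as a layer in nn.Sequential."""

    def __init__(self, lambd):
        super(LambdaLayer, self).__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)


model = nn.Sequential(
    LambdaLayer(lambda x: torch.mean(x, dim=1)),  # (N, 9, 224) -> (N, 224)
    nn.Linear(224, 4),
)

x = torch.randn(3, 9, 224)
out = model(x)
print(out.shape)  # torch.Size([3, 4])
print(sum(p.numel() for p in model[0].parameters()))  # 0: the wrapper has no weights
```

One caveat of this design: a module that stores a lambda cannot be pickled, so torch.save(model) on the whole model fails, while saving model.state_dict() works.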

@DrPingLu

Thanks for your reply, Pedro Silva! I also found we need to change axis to dim in this code layers.append(LambdaLayer(lambda x: torch.mean(x, dim=1))).
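torch.mean documents its reduction axis as dim. A small sketch of what dim=1 does to a tensor shaped like the reshaped features (the batch size of 2 is arbitrary): the size-9 dimension is averaged away, or kept as size 1 when keepdim=True.

```python
import torch

x = torch.arange(2 * 9 * 224, dtype=torch.float32).view(2, 9, 224)

reduced = torch.mean(x, dim=1)             # shape (2, 224)
kept = torch.mean(x, dim=1, keepdim=True)  # shape (2, 1, 224)

# Averaging over dim=1 is the same as summing the 9 rows and dividing by 9.
manual = x.sum(dim=1) / 9

print(reduced.shape, kept.shape)
print(torch.allclose(reduced, manual))  # True
```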

@itazLim

itazLim commented Aug 17, 2021

Hello pbnsilva!
In line 30 of 1d_ecg7_net.py, how should I write the View() function?

@Franco7Scala

Hi Pedro Silva, I would like to ask how you defined View in layers.append(View((-1, 9, 224))). I would appreciate your reply. Best wishes!
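The gist does not include View. A minimal sketch, assuming it is a thin nn.Module wrapper around Tensor.view that stores the target shape. This mirrors the LambdaLayer pattern above but is a guess, not the author's code; the 2016-value input is likewise an assumption.

```python
import torch
import torch.nn as nn


class View(nn.Module):
    """Hypothetical reshape layer: View((-1, 9, 224)) calls x.view(-1, 9, 224)."""

    def __init__(self, shape):
        super(View, self).__init__()
        self.shape = shape

    def forward(self, x):
        # view needs a compatible (e.g. contiguous) tensor; x.reshape(self.shape) is the safer variant.
        return x.view(self.shape)


layer = View((-1, 9, 224))
x = torch.randn(2, 2016)  # 2016 = 9 * 224 values per sample (assumed)
y = layer(x)
print(y.shape)  # torch.Size([2, 9, 224])
```

Because -1 lets PyTorch infer the leading size, the reshape only raises an error when the total element count is not divisible by 2016. If the preceding layers produce a different number of values per sample, samples are silently folded together, so it is worth checking the shape coming out of the last Conv2dBlock.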