Gist aletheia/849d79c3ced89c364c57dcc1929db594, created July 14, 2020 08:38
import torch
import torch.nn.functional as F


def forward(self, x):
    '''Forward pass, equivalent to the PyTorch forward method; this is where the network's computational graph is built.

    Parameters:
        x (Tensor): a tensor containing the input batch of the network

    Returns:
        A 2-D tensor of shape (batch_size, num_classes) holding the class probabilities for each input image
    '''
    x = self.conv_layer_1(x)
    x = self.conv_layer_2(x)
    x = self.dropout1(x)
    # Flatten all dimensions except the batch dimension before the dense layer
    x = torch.relu(self.fully_connected_1(x.view(x.size(0), -1)))
    # After ReLU (and dropout) x is non-negative, so leaky_relu leaves it unchanged
    x = F.leaky_relu(self.dropout2(x))
    # Softmax over the class dimension turns each row of logits into probabilities
    return F.softmax(self.fully_connected_2(x), dim=1)


def configure_optimizers(self):
    '''
    Returns:
        (Optimizer): an Adam optimizer over the model parameters, using Adam's default hyperparameters
    '''
    return torch.optim.Adam(self.parameters())
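The gist shows only forward() and configure_optimizers(). The layers they reference (conv_layer_1, conv_layer_2, dropout1, fully_connected_1, dropout2, fully_connected_2) are defined elsewhere. Below is a minimal sketch of a module that could host these methods. It assumes 1x28x28 grayscale inputs (MNIST-sized) and 10 classes. The channel counts, kernel sizes and dropout rates are chosen purely for illustration; none of them come from the gist. To keep it runnable with PyTorch alone, it subclasses plain torch.nn.Module instead of a Lightning module, and ConvClassifier is a hypothetical name.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvClassifier(nn.Module):
    """Hypothetical host for the gist's forward() and configure_optimizers().

    Assumptions (not from the gist): 1x28x28 input, 10 classes, and the
    layer hyperparameters below.
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        # Assumed block shape: conv -> ReLU -> 2x2 max-pool (halves H and W)
        self.conv_layer_1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2)
        )  # output: 32 x 14 x 14
        self.conv_layer_2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2)
        )  # output: 64 x 7 x 7
        self.dropout1 = nn.Dropout(0.25)
        # Input size must match the flattened conv output: 64 * 7 * 7
        self.fully_connected_1 = nn.Linear(64 * 7 * 7, 128)
        self.dropout2 = nn.Dropout(0.5)
        self.fully_connected_2 = nn.Linear(128, num_classes)

    def forward(self, x):
        # Same sequence of operations as the gist's forward()
        x = self.conv_layer_1(x)
        x = self.conv_layer_2(x)
        x = self.dropout1(x)
        x = torch.relu(self.fully_connected_1(x.view(x.size(0), -1)))
        x = F.leaky_relu(self.dropout2(x))
        return F.softmax(self.fully_connected_2(x), dim=1)

    def configure_optimizers(self):
        # Adam with its default hyperparameters (lr=1e-3)
        return torch.optim.Adam(self.parameters())


model = ConvClassifier()
model.eval()  # disable dropout for a deterministic shape check
with torch.no_grad():
    probs = model(torch.randn(4, 1, 28, 28))
print(probs.shape)  # torch.Size([4, 10])
print(probs.sum(dim=1))  # each row sums to 1, up to float rounding
```

The main constraint this wiring exposes is that fully_connected_1's input size must equal the flattened output of conv_layer_2 (here 64 * 7 * 7). Changing the input resolution, padding or pooling changes that number.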
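Because forward() ends in a softmax, it returns probabilities rather than raw logits, and that determines which loss it can be paired with. The gist does not show the loss, so the sketch below only assumes a standard multi-class loss. It uses random stand-in tensors to show two things. First, F.nll_loss on the log of the probabilities matches F.cross_entropy on the logits. Second, feeding the probabilities straight into F.cross_entropy applies softmax a second time and gives a different value.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(8, 10)           # stand-in for fully_connected_2's output
targets = torch.randint(0, 10, (8,))  # hypothetical class labels
probs = F.softmax(logits, dim=1)      # what forward() returns

# Reference: cross_entropy applies log_softmax internally to raw logits
ce_from_logits = F.cross_entropy(logits, targets)

# Matching pairing for probability outputs: take the log, then nll_loss
nll_from_probs = F.nll_loss(torch.log(probs), targets)

# Pitfall: cross_entropy on probabilities applies softmax a second time
ce_from_probs = F.cross_entropy(probs, targets)

print(torch.allclose(ce_from_logits, nll_from_probs, atol=1e-6))  # True
print(ce_from_logits.item(), ce_from_probs.item())
```

If both the model and the training step can be changed, a common alternative is to return F.log_softmax(...) from forward() and pair it with F.nll_loss. This avoids taking the log of probabilities that may underflow to zero.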