@madagra
Created April 9, 2023 11:26
Simple NN for fitting a function
import torch
from torch import nn


class SimpleNN(nn.Module):
    def __init__(
        self,
        num_hidden: int = 1,
        dim_hidden: int = 1,
        act: nn.Module = nn.Tanh(),
    ) -> None:
        """Basic neural network with linear layers and a non-linear activation function.

        Args:
            num_hidden (int, optional): The number of hidden layers in the model
            dim_hidden (int, optional): The number of neurons for each hidden layer
            act (nn.Module, optional): The type of non-linear activation function to be used
        """
        super().__init__()
        # input and output layers map between the 1D input/output and the hidden width
        self.layer_in = nn.Linear(1, dim_hidden)
        self.layer_out = nn.Linear(dim_hidden, 1)
        # the first hidden layer is layer_in, so only num_hidden - 1 middle layers are needed
        num_middle = num_hidden - 1
        self.middle_layers = nn.ModuleList(
            [nn.Linear(dim_hidden, dim_hidden) for _ in range(num_middle)]
        )
        self.act = act

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.act(self.layer_in(x))
        for layer in self.middle_layers:
            out = self.act(layer(out))
        return self.layer_out(out)
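A minimal sketch of how this network could be trained to fit a one-dimensional function. The target function (sin), the optimizer, and all hyperparameters below are illustrative assumptions, not part of the gist itself.

# Usage sketch: fit sin(x) on [-pi, pi].
# Assumptions: target function, Adam optimizer, and hyperparameters are
# illustrative choices, not taken from the original gist.
model = SimpleNN(num_hidden=2, dim_hidden=32)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
loss_fn = nn.MSELoss()

# training data shaped (N, 1) to match nn.Linear(1, dim_hidden)
x = torch.linspace(-torch.pi, torch.pi, 200).reshape(-1, 1)
y = torch.sin(x)

for epoch in range(1000):
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()

print(f"Final MSE: {loss.item():.4f}")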