Skip to content

Instantly share code, notes, and snippets.

@srossross
Created March 10, 2023 19:20
Show Gist options
  • Save srossross/fb93814f1999866671c863698ab66b41 to your computer and use it in GitHub Desktop.
Save srossross/fb93814f1999866671c863698ab66b41 to your computer and use it in GitHub Desktop.
# Define model
# NOTE(review): `torch` and `nn` (torch.nn) are used below, but no import is
# visible in this snippet; make sure `import torch` and `from torch import nn`
# run before this point.
class NeuralNetwork(nn.Module):
    """Stacked ``nn.RNN`` followed by a linear layer applied to every time step.

    Args:
        input_size: Number of features per time step in the input.
        output_size: Number of features produced by the final linear layer.
        hidden_dim: Size of the RNN hidden state.
        n_layers: Number of stacked RNN layers.
    """

    def __init__(self, input_size, output_size, hidden_dim, n_layers):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        # batch_first=True: the RNN consumes input shaped (batch, seq, input_size).
        self.rnn = nn.RNN(input_size, hidden_dim, n_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_size)

    def init_hidden(self, batch_size, device=None, dtype=None):
        """Return an all-zero initial hidden state.

        Args:
            batch_size: Number of sequences in the batch.
            device: Device to allocate the tensor on; ``None`` uses PyTorch's
                default device.
            dtype: Dtype of the tensor; ``None`` uses PyTorch's default dtype.

        Returns:
            A zero tensor of shape (n_layers, batch_size, hidden_dim).
        """
        return torch.zeros(
            self.n_layers, batch_size, self.hidden_dim, device=device, dtype=dtype
        )

    def forward(self, x):
        """Run the RNN over ``x`` and project every time step.

        Args:
            x: Batched input of shape (batch, seq, input_size), as implied by
                ``batch_first=True`` and ``x.size(0)`` being used as the batch size.

        Returns:
            A tuple ``(out, hidden)``:
            - ``out`` has shape (batch * seq, output_size), with the time steps
              of all sequences flattened together.
            - ``hidden`` is the final RNN hidden state, of shape
              (n_layers, batch, hidden_dim).
        """
        batch_size = x.size(0)
        # Start from a zero hidden state (init_hidden, defined above) created on
        # the same device and with the same dtype as the input. A state with
        # default device/dtype would not match CUDA or non-default-dtype inputs.
        hidden = self.init_hidden(batch_size, device=x.device, dtype=x.dtype)
        out, hidden = self.rnn(x, hidden)
        # Flatten (batch, seq, hidden_dim) -> (batch * seq, hidden_dim) so the
        # linear layer is applied to every time step. reshape only copies when
        # a view is not possible.
        out = out.reshape(-1, self.hidden_dim)
        out = self.fc(out)
        return out, hidden
# Smoke-run the model: build it, show its structure, call it eagerly once,
# then call a torch.compile'd version of its forward method once.
# Input batches are (batch, seq, features); features must equal input_size.
BATCH_SIZE, SEQ_LEN, N_FEATURES = 3, 50, 50

model = NeuralNetwork(
    input_size=N_FEATURES,
    output_size=50,
    hidden_dim=50,
    n_layers=20,
)
print(model)

# Eager forward pass; the returned (out, hidden) tuple is discarded.
model(torch.randn(BATCH_SIZE, SEQ_LEN, N_FEATURES))

# NOTE(review): the PyTorch docs compile the module itself (torch.compile(model));
# compiling the bound forward method may bypass nn.Module call machinery such
# as hooks. Confirm this is intentional (e.g. for a bug repro).
cmod = torch.compile(model.forward)
cmod(torch.randn(BATCH_SIZE, SEQ_LEN, N_FEATURES, requires_grad=True))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment