@wcneill
Created July 30, 2020 03:42
import torch
import torch.nn as nn

input_size = 50    # size of each one-hot encoded input vector
hidden_size = 100  # number of hidden units in each LSTM layer
n_layers = 2       # number of stacked LSTM layers
output_size = 50   # one score per candidate next character

lstm = nn.LSTM(input_size, hidden_size, n_layers, batch_first=True)
linear = nn.Linear(hidden_size, output_size)
# Data Flow Protocol
# 1. Network input shape: (batch_size, seq_length, input_size)
# 2. LSTM output shape:   (batch_size, seq_length, hidden_size)
# 3. Linear input shape:  (batch_size * seq_length, hidden_size)
# 4. Linear output shape: (batch_size * seq_length, output_size)
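x = get_batches(data)  # author's helper (not shown); assumed to return a float tensor shaped as in step 1
hs = None  # no initial state: nn.LSTM then starts from zero hidden and cell states
x, hs = lstm(x, hs)  # hs comes back as the tuple (h_n, c_n)
x = x.reshape(-1, hidden_size)  # merge the batch and time dimensions for the linear layer
x = linear(x)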
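
The shape flow above can be checked end to end with a small self-contained sketch. It is only an illustration under stated assumptions: `batch_size = 4` and `seq_length = 10` are arbitrary values, and random one-hot vectors stand in for the output of `get_batches(data)`, which the gist does not show.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative sizes; batch_size and seq_length are assumptions, not from the gist.
batch_size, seq_length = 4, 10
input_size, hidden_size, n_layers, output_size = 50, 100, 2, 50

lstm = nn.LSTM(input_size, hidden_size, n_layers, batch_first=True)
linear = nn.Linear(hidden_size, output_size)

# Stand-in for get_batches(data): random character indices, one-hot encoded.
idx = torch.randint(0, input_size, (batch_size, seq_length))
x = F.one_hot(idx, num_classes=input_size).float()

# Omitting the hidden state makes nn.LSTM start from zeros.
out, (h_n, c_n) = lstm(x)

# Merge batch and time so every time step becomes one row for the linear layer.
flat = out.reshape(-1, hidden_size)
scores = linear(flat)

print("input :", tuple(x.shape))
print("lstm  :", tuple(out.shape))
print("flat  :", tuple(flat.shape))
print("scores:", tuple(scores.shape))
print("h_n   :", tuple(h_n.shape))
```

Because `batch_first=True`, row `b * seq_length + t` of `flat` holds the LSTM output for sequence `b` at time step `t`, so `scores.reshape(batch_size, seq_length, output_size)` recovers per-sequence scores if they are needed later. Note that `h_n` and `c_n` keep the shape `(n_layers, batch_size, hidden_size)` regardless of `batch_first`.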