Skip to content

Instantly share code, notes, and snippets.

@zredlined
Last active January 29, 2020 22:08
Show Gist options
  • Save zredlined/22bb1ad460817109786f31e304e5b7b1 to your computer and use it in GitHub Desktop.
Save zredlined/22bb1ad460817109786f31e304e5b7b1 to your computer and use it in GitHub Desktop.
def build_model(vocab_size, embedding_dim, rnn_units, batch_size,
                *, dropout_rate=0.2, rnn_initializer='glorot_uniform',
                num_rnn_layers=2):
    """Build a stateful stacked-LSTM sequence model.

    Layer stack (with the default arguments, identical to the original
    hard-coded model):

        Embedding -> Dropout -> [LSTM -> Dropout] * num_rnn_layers -> Dense

    Args:
        vocab_size: Number of distinct input tokens. Also used as the width
            of the final Dense layer, so the model emits one score per
            vocabulary entry at every time step.
        embedding_dim: Size of each embedding vector.
        rnn_units: Number of units in every LSTM layer.
        batch_size: Fixed batch size baked into the model's input shape.
        dropout_rate: Dropout rate applied after the embedding and after
            every LSTM layer. Keyword-only; defaults to 0.2.
        rnn_initializer: Initializer for the LSTM recurrent kernels.
            Keyword-only; defaults to 'glorot_uniform'.
        num_rnn_layers: How many LSTM (+ Dropout) layers to stack.
            Keyword-only; defaults to 2.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model. The final Dense layer
        has no activation, so its outputs are raw scores (logits), not
        probabilities.

    Raises:
        ValueError: If ``num_rnn_layers`` is less than 1.
    """
    if num_rnn_layers < 1:
        raise ValueError(
            f"num_rnn_layers must be >= 1, got {num_rnn_layers!r}")

    layers = [
        # stateful=True on the LSTMs requires a fixed batch size, hence the
        # explicit batch dimension; the sequence length (None) stays free.
        # NOTE(review): `batch_input_shape` is a TF2 / Keras 2 argument and is
        # not accepted by Keras 3 layers — confirm the TF version in use.
        tf.keras.layers.Embedding(vocab_size, embedding_dim,
                                  batch_input_shape=[batch_size, None]),
        tf.keras.layers.Dropout(dropout_rate),
    ]
    for _ in range(num_rnn_layers):
        layers.append(tf.keras.layers.LSTM(
            rnn_units,
            # Emit an output at every time step, so the next LSTM and the
            # Dense head both see the full sequence.
            return_sequences=True,
            # Carry the final state of each batch into the next batch.
            stateful=True,
            recurrent_initializer=rnn_initializer))
        layers.append(tf.keras.layers.Dropout(dropout_rate))
    layers.append(tf.keras.layers.Dense(vocab_size))
    return tf.keras.Sequential(layers)
# Build the model from hyperparameters defined outside this snippet
# (vocab, embedding_dim, rnn_units, BATCH_SIZE), then print its
# layer-by-layer summary.
model = build_model(
    vocab_size=len(vocab),
    embedding_dim=embedding_dim,
    rnn_units=rnn_units,
    batch_size=BATCH_SIZE,
)
model.summary()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment