-
-
Save amankharwal/a76c7d8d059b78d8a9a89cda693ff53c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def create_model(
    max_sequence_len,
    total_words,
    *,
    embedding_dim=10,
    lstm_units=100,
    dropout_rate=0.1,
):
    """Build and compile an Embedding -> LSTM -> Dropout -> Dense(softmax) model.

    Args:
        max_sequence_len: Length of the full sequences. The model's input
            length is ``max_sequence_len - 1``, so this must be at least 2.
        total_words: Used both as the embedding's input dimension and as the
            number of units in the softmax output layer. Must be at least 1.
        embedding_dim: Output dimension of the embedding layer.
        lstm_units: Number of units in the LSTM layer.
        dropout_rate: Rate passed to the dropout layer that follows the LSTM.

    Returns:
        The compiled ``Sequential`` model (categorical cross-entropy loss,
        Adam optimizer).

    Raises:
        ValueError: If ``max_sequence_len`` is below 2 or ``total_words`` is
            below 1, since either would give a layer with no usable size.
    """
    if max_sequence_len < 2:
        raise ValueError(
            f"max_sequence_len must be at least 2, got {max_sequence_len!r}"
        )
    if total_words < 1:
        raise ValueError(f"total_words must be at least 1, got {total_words!r}")

    # One position shorter than the full sequence -- presumably the final
    # token of each sequence is held out as the label; confirm against the
    # data-preparation code.
    input_len = max_sequence_len - 1

    model = Sequential()

    # Add Input Embedding Layer
    # NOTE(review): newer Keras releases reportedly deprecate `input_length`;
    # confirm against the installed version before removing it.
    model.add(Embedding(total_words, embedding_dim, input_length=input_len))

    # Add Hidden Layer 1 - LSTM Layer
    model.add(LSTM(lstm_units))
    model.add(Dropout(dropout_rate))

    # Add Output Layer
    model.add(Dense(total_words, activation='softmax'))

    model.compile(loss='categorical_crossentropy', optimizer='adam')
    return model
# Build the model from the sequence length and vocabulary size defined
# elsewhere in the script, then print its layer-by-layer summary.
model = create_model(max_sequence_len, total_words)
model.summary()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.