@TheLoneNut
Created February 15, 2018 15:55
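from keras.models import Sequential
from keras.layers import BatchNormalization, Dense, LSTM


def create_base_network(in_dims, out_dims):
    """
    Base network to be shared, i.e. one model instance reused on several inputs
    so that every input goes through the same weights.

    in_dims: input shape without the batch axis, i.e. (timesteps, features).
    out_dims: size of the output vector produced for each input sequence.
    """
    model = Sequential()
    # Normalize the raw input features before the recurrent layers.
    model.add(BatchNormalization(input_shape=in_dims))
    # Two stacked LSTMs: the first returns the full sequence so the second can consume it,
    # the second returns only its final output.
    model.add(LSTM(512, return_sequences=True, dropout=0.2, recurrent_dropout=0.2, implementation=2))
    model.add(LSTM(512, return_sequences=False, dropout=0.2, recurrent_dropout=0.2, implementation=2))
    model.add(BatchNormalization())
    # Fully connected head mapping the sequence summary to an out_dims-sized vector.
    model.add(Dense(512, activation='relu'))
    model.add(BatchNormalization())
    model.add(Dense(out_dims, activation='relu'))
    model.add(BatchNormalization())
    return model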
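As a rough sanity check on the size of this network, the sketch below counts its weights by hand. It assumes Keras defaults (use_bias=True on the LSTM and Dense layers; center=True and scale=True on BatchNormalization over the last axis). The helper names and the `units` parameter are hypothetical and not part of the gist, and the example shape (100 timesteps, 10 features, out_dims=128) is arbitrary.

```python
def lstm_params(input_dim, units):
    """Keras LSTM with a bias: input kernel, recurrent kernel and bias, each for the 4 gates."""
    return 4 * units * (input_dim + units + 1)


def dense_params(input_dim, units):
    """Kernel plus bias of a Dense layer."""
    return input_dim * units + units


def batchnorm_params(features):
    """(trainable, non_trainable) for BatchNormalization over the last axis:
    gamma/beta are learned, moving mean/variance are not."""
    return 2 * features, 2 * features


def base_network_param_count(in_dims, out_dims, units=512):
    """Hypothetical helper: (trainable, non_trainable) counts for create_base_network."""
    _timesteps, features = in_dims  # sequence length does not change the weight count
    trainable = non_trainable = 0

    # The four BatchNormalization layers: on the input, after the LSTMs, after each Dense.
    for width in (features, units, units, out_dims):
        t, n = batchnorm_params(width)
        trainable += t
        non_trainable += n

    trainable += lstm_params(features, units)   # first LSTM (return_sequences=True)
    trainable += lstm_params(units, units)      # second LSTM (final output only)
    trainable += dense_params(units, units)     # Dense(512, relu)
    trainable += dense_params(units, out_dims)  # Dense(out_dims, relu)
    return trainable, non_trainable


trainable, non_trainable = base_network_param_count((100, 10), 128)
print(trainable, non_trainable)  # 3500948 2324
```

The sequence length drops out of the count because the LSTM applies the same weights at every timestep, and dropout adds no weights. Under the assumptions above, these totals should match what `model.summary()` reports for the same shapes.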