Skip to content

Instantly share code, notes, and snippets.

@deeperunderstanding
Last active June 13, 2022 13:10
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save deeperunderstanding/40caa1eff607b9ef2243f90ede43890d to your computer and use it in GitHub Desktop.
def _dense_head(features, units, activation, hidden_units):
    """Map ``features`` to a ``units``-wide output via Dense -> BatchNorm -> PReLU -> Dense."""
    head = Dense(hidden_units)(features)
    head = BatchNormalization()(head)
    head = PReLU()(head)
    return Dense(units, activation=activation)(head)


def create_encoder(latent_dim: int, cat_dim: int, window_size: int, input_dim: int, *,
                   lstm_units=(128, 64), dense_units: int = 64):
    """Build a bidirectional-LSTM sequence autoencoder with two codes.

    Despite its name, the returned model contains both the encoder and the
    decoder. The encoder compresses each input window into two codes:
    a continuous code (linear activation) and a categorical code (softmax).
    The decoder rebuilds the window from the concatenation of the two codes.

    NOTE(review): Input, Dense, LSTM, Bidirectional, TimeDistributed,
    BatchNormalization, ELU, PReLU, Concatenate, RepeatVector, Subtract and
    Model are not imported in the visible source. They are presumably from
    ``keras`` / ``tensorflow.keras``; confirm.

    Args:
        latent_dim: Width of the continuous (linear) code.
        cat_dim: Number of entries in the softmax categorical code.
        window_size: Time steps per input window. The decoder repeats the
            combined code this many times, so it must be a fixed integer.
        input_dim: Features per time step.
        lstm_units: Widths of the stacked bidirectional encoder LSTMs, first
            layer first. The decoder uses the same widths in reverse order.
            The default reproduces the original 128 -> 64 encoder and
            64 -> 128 decoder.
        dense_units: Width of the time-distributed input/output projections
            and of the hidden layer in each code head.

    Returns:
        A Model taking input of shape (batch, window_size, input_dim). Its
        outputs are, in order:
        ``[reconstruction, latent code, category softmax, input - reconstruction]``.

    Raises:
        ValueError: If ``lstm_units`` is empty.
    """
    lstm_units = tuple(lstm_units)
    if not lstm_units:
        raise ValueError("lstm_units must contain at least one layer width")

    input_layer = Input(shape=(window_size, input_dim))

    # Encoder. Layers are created in exactly the original order so Keras'
    # auto-generated layer names (and therefore saved-weight compatibility)
    # are unchanged under the default arguments.
    code = TimeDistributed(Dense(dense_units, activation='linear'))(input_layer)
    last = len(lstm_units) - 1
    for index, units in enumerate(lstm_units):
        # Only the final encoder LSTM collapses the time axis into one vector.
        code = Bidirectional(LSTM(units, return_sequences=index < last))(code)
        code = BatchNormalization()(code)
        code = ELU()(code)

    # Two heads on the shared encoding: categorical first, then continuous.
    cat = _dense_head(code, cat_dim, 'softmax', dense_units)
    latent_repr = _dense_head(code, latent_dim, 'linear', dense_units)

    # Decoder: broadcast the combined code across every time step, then
    # run the encoder's LSTM widths in reverse order.
    decode = Concatenate()([latent_repr, cat])
    decode = RepeatVector(window_size)(decode)
    for units in reversed(lstm_units):
        decode = Bidirectional(LSTM(units, return_sequences=True))(decode)
        decode = ELU()(decode)
    decode = TimeDistributed(Dense(dense_units))(decode)
    decode = ELU()(decode)
    decode = TimeDistributed(Dense(input_dim, activation='linear'))(decode)

    # Element-wise input minus reconstruction, exposed as a fourth output.
    error = Subtract()([input_layer, decode])
    return Model(input_layer, [decode, latent_repr, cat, error])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment