@micaleel
Created March 20, 2021 13:21
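# Imports assumed for TensorFlow 2.x (tf.keras); the original snippet omits them.
from tensorflow.keras.layers import Dense, Embedding, Flatten, Input, multiply
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2


# The enclosing function is not shown in the original snippet; this name and
# signature are assumed from the variables the body uses.
def get_model(num_users, num_items, embed_size, regs=None):
    """Matrix-factorization (GMF-style) model: sigmoid(h . (p_u * q_i) + b)."""
    # L2 penalties for the user and item embedding tables.
    regs = regs or [0, 0]

    # Each example is a single user ID and a single item ID.
    user_input = Input(shape=(1,), dtype="int32", name="user_input")
    item_input = Input(shape=(1,), dtype="int32", name="item_input")

    # Latent-factor tables. The sequence length (1) is already fixed by
    # Input(shape=(1,)), so the deprecated input_length argument is dropped.
    user_embeds = Embedding(
        input_dim=num_users,
        output_dim=embed_size,
        embeddings_initializer="normal",
        embeddings_regularizer=l2(regs[0]),
    )
    item_embeds = Embedding(
        input_dim=num_items,
        output_dim=embed_size,
        embeddings_initializer="normal",
        embeddings_regularizer=l2(regs[1]),
    )

    # (batch, 1, embed_size) -> (batch, embed_size)
    user_flattened = Flatten()(user_embeds(user_input))
    item_flattened = Flatten()(item_embeds(item_input))

    # Element-wise product of the user and item latent vectors.
    predict_vector = multiply([user_flattened, item_flattened])

    # Weighted sum of the product, squashed to an interaction probability.
    prediction = Dense(
        1,
        activation="sigmoid",
        kernel_initializer="lecun_uniform",
        name="prediction",
    )(predict_vector)

    model = Model(inputs=[user_input, item_input], outputs=prediction)
    return model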
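To make explicit what the Keras graph above computes, here is a minimal NumPy sketch of the same forward pass: an embedding lookup for each ID, an element-wise product, then a single sigmoid unit. The toy sizes and random weights are assumptions standing in for a trained model, and `sigmoid` and `gmf_predict` are hypothetical helper names. The weight draws mimic the Keras defaults the snippet relies on ("normal" is RandomNormal with stddev 0.05, "lecun_uniform" samples from U(-sqrt(3/fan_in), sqrt(3/fan_in)), and the Dense bias starts at zero).

```python
import numpy as np

# Hypothetical toy sizes, chosen only for illustration.
rng = np.random.default_rng(0)
num_users, num_items, embed_size = 4, 5, 3

# Stand-ins for the model's weights (assumed, not trained).
P = rng.normal(0.0, 0.05, size=(num_users, embed_size))  # user embedding table
Q = rng.normal(0.0, 0.05, size=(num_items, embed_size))  # item embedding table
limit = np.sqrt(3.0 / embed_size)                         # LeCun uniform bound
h = rng.uniform(-limit, limit, size=(embed_size, 1))      # Dense(1) kernel
b = 0.0                                                   # Dense(1) bias (zeros init)


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def gmf_predict(user_ids, item_ids):
    """Score (user, item) pairs the same way the Keras graph does."""
    p_u = P[user_ids]           # Embedding lookup + Flatten -> (batch, embed_size)
    q_i = Q[item_ids]           # Embedding lookup + Flatten -> (batch, embed_size)
    predict_vector = p_u * q_i  # multiply([...]) is an element-wise product
    return sigmoid(predict_vector @ h + b)  # Dense(1, activation="sigmoid")


scores = gmf_predict(np.array([0, 1, 3]), np.array([2, 2, 4]))
print(scores.shape)  # (3, 1)
```

Because the only learned interaction is the weighted element-wise product, fixing every entry of h to 1 with b = 0 reduces the score to sigmoid(p_u . q_i), i.e. a plain dot-product matrix factorization; learning h lets the model weight each latent dimension differently.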