Skip to content

Instantly share code, notes, and snippets.

@jeremystan
Last active February 13, 2021 15:38
Show Gist options
  • Save jeremystan/c236000a4159f9d47c28784fa6693c45 to your computer and use it in GitHub Desktop.
Save jeremystan/c236000a4159f9d47c28784fa6693c45 to your computer and use it in GitHub Desktop.
A naive Keras architecture for predicting the next item picked in a shopping list.
from keras.models import Model
from keras.layers.core import Dense, Reshape, Lambda
from keras.layers import Input, Embedding, merge
from keras import backend as K
# Builds and compiles `mdl`: a network that scores every product as the
# "next item picked", given the previously picked product, the store and the
# shopper, and normalises those scores with a soft-max restricted to the
# candidate products supplied as a fourth input.

# Vocabulary sizes: number of distinct product, store and shopper IDs
N_products = 1000000
N_stores = 1000
N_shoppers = 10000

# Integer IDs representing 1-hot encodings (one ID per sample and input)
prior_in = Input(shape=(1,))
store_in = Input(shape=(1,))
shopper_in = Input(shape=(1,))

# Dense N-hot encoding for candidate products. It is multiplied into the
# exponentiated scores below, so a 0 entry removes that product from the
# soft-max.
candidates_in = Input(shape=(N_products,))

# Embeddings: every ID is looked up as a 10-dimensional vector
prior = Embedding(N_products, 10)(prior_in)
store = Embedding(N_stores, 10)(store_in)
shopper = Embedding(N_shoppers, 10)(shopper_in)

# Reshape each embedding to a flat 10-vector and merge all embeddings
# together by concatenation
reshape = Reshape(target_shape=(10,))
combined = merge([reshape(prior), reshape(store), reshape(shopper)],
                 mode='concat')

# Hidden layers, narrowing down to a 10-unit linear bottleneck
hidden_1 = Dense(1024, activation='relu')(combined)
hidden_2 = Dense(512, activation='relu')(hidden_1)
hidden_3 = Dense(256, activation='relu')(hidden_2)
hidden_4 = Dense(10, activation='linear')(hidden_3)

# Final 'fan-out' into the space of future products (one score per product)
final = Dense(N_products, activation='linear')(hidden_4)

# Ensure we do not overflow when we exponentiate. The maximum is taken per
# sample (last axis, keepdims so it broadcasts back over the products): a
# batch-wide maximum would let one sample's large score push every score of
# another sample towards exp() underflow.
final = Lambda(lambda x: x - K.max(x, axis=-1, keepdims=True))(final)

# Masked soft-max using Lambda and merge-multiplication
exponentiate = Lambda(lambda x: K.exp(x))(final)
masked = merge([exponentiate, candidates_in], mode='mul')
# Normalise per sample so that each row sums to 1 over its own candidates
# instead of the whole batch sharing a single normaliser.
# NOTE: a row of candidates_in that is all zeros makes this sum 0 and the
# division undefined; callers are expected to supply at least one candidate.
predicted = Lambda(lambda x: x / K.sum(x, axis=-1, keepdims=True))(masked)

# Compile with categorical crossentropy and adam
mdl = Model(input=[prior_in, store_in, shopper_in, candidates_in],
            output=predicted)
mdl.compile(loss='categorical_crossentropy',
            optimizer='adam',
            metrics=['accuracy'])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment