@VXU1230
Last active March 19, 2019 23:00
build model
import tensorflow as tf
from tensorflow.keras import layers


class MyModel(tf.keras.Model):
    """Skip-gram model: embeds a target word and a context word and
    predicts whether they co-occur (binary classification via a sigmoid)."""

    def __init__(self, vocab_size, embed_size):
        super(MyModel, self).__init__()
        # Shared embedding matrix used for both target and context words
        self.embedding = layers.Embedding(
            vocab_size,
            embed_size,
            embeddings_initializer=tf.keras.initializers.glorot_normal(),
            name='embedding')
        self.reshape = layers.Reshape((embed_size, 1))
        self.reshape2 = layers.Reshape((1,))
        self.dot = layers.Dot(axes=1)
        self.dense = layers.Dense(1, activation='sigmoid')

    @tf.function
    def call(self, target, context):
        # Look up embedding vectors for the target and context word ids
        target = self.reshape(self.embedding(target))
        context = self.reshape(self.embedding(context))
        # Dot product of the two embeddings measures their similarity
        dot_product = self.reshape2(self.dot([target, context]))
        # Sigmoid output: probability that (target, context) is a true pair
        return self.dense(dot_product)


EMBED_SIZE = 300
INIT_LR = 0.05


def build_model():
    # VOCAB_SIZE is defined earlier in the gist (size of the corpus vocabulary)
    model = MyModel(VOCAB_SIZE, EMBED_SIZE)
    optimizer = tf.keras.optimizers.Nadam(learning_rate=INIT_LR)
    loss_fn = tf.keras.losses.BinaryCrossentropy()
    return model, optimizer, loss_fn


model, optimizer, loss_fn = build_model()
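
A minimal sketch of how the returned model, optimizer, and loss function might be wired into a custom training step. The train_step function and its batch arguments (target ids, context ids, and 0/1 labels for negative-sampled pairs) are assumptions for illustration; the actual training loop lives elsewhere in the gist.

# Hypothetical usage: one gradient step on a batch of (target, context, label)
# examples, where label is 1 for a true skip-gram pair and 0 for a negative sample.
@tf.function
def train_step(target_batch, context_batch, labels):
    with tf.GradientTape() as tape:
        predictions = model(target_batch, context_batch)
        loss = loss_fn(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss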