Skip to content

Instantly share code, notes, and snippets.

@eustin
Created June 8, 2020 04:40
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save eustin/f568c7157b54fecd623f1ada5e8e101a to your computer and use it in GitHub Desktop.
class ExpandBatchLayer(tf.keras.layers.Layer):
    """Pair every document with its query by concatenating embeddings.

    Called with ``[queries, docs]``. It repeats ``queries[:, 0]`` once per
    document along axis 1, then concatenates the result onto ``docs`` along
    the last axis.

    NOTE(review): shapes are not validated here. The intended layout is
    presumably ``queries``: (batch, 1, query_dims) and ``docs``:
    (batch, num_docs, doc_dims), giving an output of shape
    (batch, num_docs, query_dims + doc_dims). Confirm this against the caller.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, input):
        """Concatenate the repeated first query slot onto each document.

        Args:
            input: a pair ``(queries, docs)``. Only index 0 along axis 1 of
                ``queries`` is used; any further entries on that axis are
                ignored.

        Returns:
            ``tf.concat([expanded_queries, docs], axis=-1)``, where
            ``expanded_queries`` holds ``tf.shape(docs)[1]`` copies of
            ``queries[:, 0]`` along axis 1.
        """
        # `input` shadows the builtin, but the parameter name is kept so that
        # keyword callers (`call(input=...)`) keep working.
        queries, docs = input
        # Read only the dimension that is needed. Unstacking the whole shape
        # required docs to be exactly rank 3 and left two locals unused.
        num_docs = tf.shape(docs)[1]
        # Gathering index 0 `num_docs` times repeats the first query slot for
        # every document, using the dynamic (run-time) document count.
        first_slot = tf.zeros([num_docs], dtype=tf.int32)
        expanded_queries = tf.gather(queries, first_slot, axis=1)
        return tf.concat([expanded_queries, docs], axis=-1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment