@eliorc
Created July 7, 2019 07:34
CDME
from typing import List, Optional

import tensorflow as tf


class ContextualDynamicMetaEmbedding(tf.keras.layers.Layer):
    def __init__(self,
                 embedding_matrices: List[tf.keras.layers.Embedding],
                 output_dim: Optional[int] = None,
                 n_lstm_units: int = 2,
                 name: str = 'contextual_dynamic_meta_embedding',
                 **kwargs):
        """
        :param embedding_matrices: List of embedding layers
        :param output_dim: Dimension of the output embedding
        :param n_lstm_units: Number of units in each LSTM (notated as `m` in the original article)
        :param name: Layer name
        """
        super().__init__(name=name, **kwargs)

        # Validate that all the embedding matrices have the same vocabulary size
        if not len(set(e.input_dim for e in embedding_matrices)) == 1:
            raise ValueError('Vocabulary sizes (first dimension) of all embedding matrices must match')

        # If no output_dim is supplied, use the minimum dimension from the given matrices
        self.output_dim = output_dim or min(e.output_dim for e in embedding_matrices)
        self.n_lstm_units = n_lstm_units

        self.embedding_matrices = embedding_matrices
        self.n_embeddings = len(self.embedding_matrices)

        # One linear projection per embedding matrix, mapping each into the shared output_dim space
        self.projections = [tf.keras.layers.Dense(units=self.output_dim,
                                                  activation=None,
                                                  name='projection_{}'.format(i),
                                                  dtype=self.dtype) for i, e in enumerate(self.embedding_matrices)]

        # BiLSTM used to contextualize the projected embeddings before attention
        self.bilstm = tf.keras.layers.Bidirectional(
            tf.keras.layers.LSTM(units=self.n_lstm_units, return_sequences=True),
            name='bilstm',
            dtype=self.dtype)

        # Scores each contextualized embedding; a softmax over the embeddings axis yields the attention weights
        self.attention = tf.keras.layers.Dense(units=1,
                                               activation=None,
                                               name='attention',
                                               dtype=self.dtype)

    def call(self, inputs, **kwargs) -> tf.Tensor:
        batch_size, time_steps = inputs.shape[:2]

        # Embedding lookup
        embedded = [e(inputs) for e in self.embedding_matrices]  # List of shape=(batch_size, time_steps, channels_i)

        # Project embeddings into the shared space
        projected = tf.reshape(tf.concat([p(e) for p, e in zip(self.projections, embedded)], axis=-1),
                               shape=(batch_size, time_steps, -1, self.output_dim),
                               name='projected')  # shape=(batch_size, time_steps, n_embeddings, output_dim)

        # Contextualize
        context = self.bilstm(
            tf.reshape(projected, shape=(batch_size * self.n_embeddings, time_steps,
                                         self.output_dim)))  # shape=(batch_size * n_embeddings, time_steps, n_lstm_units*2)
        context = tf.reshape(context, shape=(batch_size, time_steps, self.n_embeddings,
                                             self.n_lstm_units * 2))  # shape=(batch_size, time_steps, n_embeddings, n_lstm_units*2)

        # Calculate attention coefficients
        alphas = self.attention(context)  # shape=(batch_size, time_steps, n_embeddings, 1)
        alphas = tf.nn.softmax(alphas, axis=-2)  # shape=(batch_size, time_steps, n_embeddings, 1)

        # Attend: weighted sum of the projected embeddings
        output = tf.squeeze(tf.matmul(
            tf.transpose(projected, perm=[0, 1, 3, 2]), alphas),  # Attending
            axis=-1,
            name='output')  # shape=(batch_size, time_steps, output_dim)

        return output
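
For reference, a minimal usage sketch. The vocabulary size, embedding dimensions, and random token batch below are invented for illustration and are not part of the original gist; with no output_dim supplied, the layer projects everything to the smallest embedding dimension.

# Usage sketch (hypothetical shapes and data, for illustration only)
import numpy as np
import tensorflow as tf

VOCAB_SIZE = 100  # assumed vocabulary size; all embeddings must share it

# Two embedding layers over the same vocabulary, with different dimensions
embeddings = [tf.keras.layers.Embedding(input_dim=VOCAB_SIZE, output_dim=50),
              tf.keras.layers.Embedding(input_dim=VOCAB_SIZE, output_dim=25)]

cdme = ContextualDynamicMetaEmbedding(embedding_matrices=embeddings)

tokens = np.random.randint(0, VOCAB_SIZE, size=(4, 10))  # (batch_size, time_steps)
meta_embedded = cdme(tokens)
print(meta_embedded.shape)  # (4, 10, 25) -- output_dim defaults to the minimum embedding dimension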