@NMZivkovic
Created August 3, 2019 14:05
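import tensorflow as tf


class MaskHandler(object):
    def padding_mask(self, sequence):
        # 1.0 where the token id is 0 (padding), 0.0 elsewhere
        sequence = tf.cast(tf.math.equal(sequence, 0), tf.float32)
        # Shape (batch, 1, 1, seq_len) so it broadcasts over heads and query positions
        return sequence[:, tf.newaxis, tf.newaxis, :]

    def look_ahead_mask(self, size):
        # 1.0 strictly above the diagonal marks the future positions to hide
        mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)
        return mask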
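A small worked example of what the two mask methods produce. It re-creates them as free functions so the snippet runs on its own (TensorFlow 2.x in eager mode is assumed). The token ids are made-up illustration values, and combining the two masks with `tf.maximum` is one common way to build a decoder self-attention mask, not something this gist itself does.

```python
import tensorflow as tf


def padding_mask(sequence):
    # Same logic as MaskHandler.padding_mask: 1.0 marks padding (token id 0)
    mask = tf.cast(tf.math.equal(sequence, 0), tf.float32)
    return mask[:, tf.newaxis, tf.newaxis, :]


def look_ahead_mask(size):
    # Same logic as MaskHandler.look_ahead_mask: 1.0 marks future positions
    return 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)


# Hypothetical batch holding one sequence of length 4 whose last token is padding
tokens = tf.constant([[7, 3, 9, 0]])

pad = padding_mask(tokens)   # shape (1, 1, 1, 4)
ahead = look_ahead_mask(4)   # shape (4, 4)

# A key position is hidden if it is padding OR lies in the future.
# Broadcasting (1, 1, 1, 4) against (4, 4) yields shape (1, 1, 4, 4).
combined = tf.maximum(pad, ahead)

print(pad.numpy().squeeze())  # [0. 0. 0. 1.]
# ahead rows:    [0 1 1 1], [0 0 1 1], [0 0 0 1], [0 0 0 0]
print(ahead.numpy())
# combined rows: [0 1 1 1], [0 0 1 1], [0 0 0 1], [0 0 0 1]
print(combined.numpy()[0, 0])
```

`band_part(x, -1, 0)` keeps the lower triangle including the diagonal, so subtracting it from 1 leaves 1.0 only where a query would look at a later key.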
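Both masks use 1.0 for positions to hide, so they are meant to be turned into very negative attention logits rather than multiplied into the attention weights. The sketch below assumes the widely used convention of adding `mask * -1e9` to the scaled scores before the softmax; the function name `masked_attention` and the random inputs are hypothetical, chosen only for illustration.

```python
import tensorflow as tf


def masked_attention(q, k, v, mask=None):
    # Scaled dot-product scores with shape (..., seq_q, seq_k)
    d_k = tf.cast(tf.shape(k)[-1], tf.float32)
    scores = tf.matmul(q, k, transpose_b=True) / tf.math.sqrt(d_k)
    if mask is not None:
        # mask == 1.0 marks positions to hide; push their logits toward -inf
        scores += mask * -1e9
    weights = tf.nn.softmax(scores, axis=-1)
    return tf.matmul(weights, v), weights


size = 4
# Look-ahead mask built the same way as MaskHandler.look_ahead_mask
mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)

tf.random.set_seed(0)
q = tf.random.normal((1, size, 8))
k = tf.random.normal((1, size, 8))
v = tf.random.normal((1, size, 8))

out, weights = masked_attention(q, k, v, mask)
# Row i only attends to positions 0..i; row 0 puts all of its weight on position 0
print(weights.numpy()[0].round(3))
```

Adding a large negative number instead of zeroing weights after the softmax keeps every row a proper probability distribution over the visible positions.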