Skip to content

Instantly share code, notes, and snippets.

@lazuxd
Created July 20, 2021 12:57
Show Gist options
  • Save lazuxd/471863f336736821c6076a7e6056977d to your computer and use it in GitHub Desktop.
Save lazuxd/471863f336736821c6076a7e6056977d to your computer and use it in GitHub Desktop.
def _call_gru(self,
              level: int,
              x: Union[np.ndarray, tf.Tensor]) -> tf.Tensor:
    """Advance the GRU hidden state stored at ``self.a[level]`` by one step.

    The stored state is first cut down to its first ``x.shape[0]`` rows and
    written back to ``self.a[level]``. It is then combined with ``x``
    (``[p, q]`` means concatenation along axis 1)::

        r      = sigmoid([a, x] @ wr + br)        # relevance gate
        u      = sigmoid([a, x] @ wu + bu)        # update gate
        a_cand = tanh([r * a, x] @ wa + ba)
        a_new  = u * a_cand + (1 - u) * a

    All weights and biases are looked up with ``level``.

    Args:
        level: Index into ``self.a`` and into the parameter containers
            ``self.wr``/``self.br``, ``self.wu``/``self.bu`` and
            ``self.wa``/``self.ba``.
        x: Input for this step. Its leading dimension is used as the number
            of state rows to keep; presumably laid out as (batch, features)
            -- TODO confirm against callers.

    Returns:
        The new hidden state, which is also stored in ``self.a[level]``.
    """
    def _affine(inputs, weights, bias):
        # Matrix product followed by bias addition.
        return tf.linalg.matmul(inputs, weights) + bias

    batch_rows = x.shape[0]
    # Keep only as many state rows as there are input rows. The truncated
    # state is persisted immediately, before the new state is computed.
    self.a[level] = self.a[level][0:batch_rows]
    prev_state = self.a[level]

    gate_inputs = tf.concat([prev_state, x], axis=1)
    r_gate = tf.math.sigmoid(
        _affine(gate_inputs, self.wr[level], self.br[level]))
    u_gate = tf.math.sigmoid(
        _affine(gate_inputs, self.wu[level], self.bu[level]))

    # The relevance gate scales the previous state before it feeds the
    # candidate computation.
    candidate_inputs = tf.concat(
        [tf.math.multiply(r_gate, prev_state), x], axis=1)
    candidate = tf.math.tanh(
        _affine(candidate_inputs, self.wa[level], self.ba[level]))

    # Interpolate between the candidate and the previous state.
    self.a[level] = (tf.math.multiply(u_gate, candidate)
                     + tf.math.multiply((1 - u_gate), prev_state))
    return self.a[level]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment