Skip to content

Instantly share code, notes, and snippets.

@lazuxd
Created July 20, 2021 12:58
Show Gist options
  • Save lazuxd/d6a5c8b934b5999b282cd9b29a5fe723 to your computer and use it in GitHub Desktop.
Save lazuxd/d6a5c8b934b5999b282cd9b29a5fe723 to your computer and use it in GitHub Desktop.
def _call_lstm(self,
               level: int,
               x: Union[np.ndarray, tf.Tensor]) -> tf.Tensor:
    """Run one LSTM cell step for layer ``level`` and return the new activation.

    The stored activation ``self.a[level]`` and cell state ``self.c[level]``
    are first cut down to their leading ``x.shape[0]`` rows, then replaced
    by the values computed for this step. Both replacements persist on
    ``self`` after the call.

    Args:
        level: Index selecting the per-layer state (``self.a``, ``self.c``),
            weights (``self.wu``, ``self.wf``, ``self.wo``, ``self.wc``) and
            biases (``self.bu``, ``self.bf``, ``self.bo``, ``self.bc``).
        x: Input for this step. It is concatenated with the previous
            activation along axis 1, so it is presumably 2-D with the batch
            on axis 0 -- confirm against the caller.

    Returns:
        The updated activation, which is also stored in ``self.a[level]``.
    """
    rows = x.shape[0]

    # Trim the carried-over state to the leading rows matching this input.
    prev_activation = self.a[level][:rows]
    self.a[level] = prev_activation
    prev_cell = self.c[level][:rows]
    self.c[level] = prev_cell

    # Every gate reads the same [previous activation, input] matrix.
    stacked = tf.concat([prev_activation, x], axis=1)

    def pre_activation(weights, biases):
        """Return ``stacked @ weights[level] + biases[level]``."""
        return tf.linalg.matmul(stacked, weights[level]) + biases[level]

    gate_update = tf.math.sigmoid(pre_activation(self.wu, self.bu))
    gate_forget = tf.math.sigmoid(pre_activation(self.wf, self.bf))
    gate_output = tf.math.sigmoid(pre_activation(self.wo, self.bo))
    candidate = tf.math.tanh(pre_activation(self.wc, self.bc))

    # New cell state: gated candidate plus the gated previous cell state.
    written = tf.math.multiply(gate_update, candidate)
    retained = tf.math.multiply(gate_forget, prev_cell)
    next_cell = written + retained
    self.c[level] = next_cell

    next_activation = tf.math.multiply(gate_output, tf.math.tanh(next_cell))
    self.a[level] = next_activation
    return next_activation
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment