Skip to content

Instantly share code, notes, and snippets.

@jeongukjae
Last active September 28, 2020 07:26
Show Gist options
  • Save jeongukjae/a650b45b3af1d34e90607212c42fb0fa to your computer and use it in GitHub Desktop.
Save jeongukjae/a650b45b3af1d34e90607212c42fb0fa to your computer and use it in GitHub Desktop.
import tensorflow as tf
class TransformerEncoder(tf.keras.layers.Layer):
    """Placeholder Transformer encoder layer that returns its input unchanged.

    ``call`` is currently an identity mapping. The layer creates no weights
    and performs no attention or feed-forward computation.
    """

    def __init__(self, **kwargs):
        """Create the layer.

        Args:
            **kwargs: Standard Keras ``Layer`` keyword arguments (e.g. ``name``,
                ``dtype``, ``trainable``), forwarded to the base class.
                Accepting them is required for ``Layer.from_config``, which
                rebuilds a layer via ``cls(**config)`` using the base config
                (``name``/``trainable``/``dtype``). Without them, cloning or
                reloading a model from its config fails with ``TypeError``.
        """
        super().__init__(**kwargs)

    def call(self, hidden_state):
        """Return ``hidden_state`` unchanged.

        Args:
            hidden_state: Input tensor; no shape or dtype constraint is
                enforced here.

        Returns:
            The same tensor that was passed in.
        """
        return hidden_state
if __name__ == "__main__":
    # Encoder hyperparameters. Only `hidden_size` is read below (as the width
    # of the model input); the other three are assigned but not yet consumed
    # by this script.
    hidden_size = 256
    num_heads = 2
    intermediate_size = 768
    activation_function = "relu"

    # Symbolic float32 input of shape (batch, None, hidden_size); the None
    # axis is variable-length (presumably the sequence axis).
    inputs = tf.keras.Input((None, hidden_size), dtype=tf.float32)
    encoder = TransformerEncoder()
    outputs = encoder(inputs)
    model = tf.keras.Model(inputs, outputs, name="model")

    # Attach the model to a TensorBoard callback (no training is run) so the
    # model graph can be inspected from the ./logs directory.
    tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir="./logs")
    tensorboard_callback.set_model(model)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment