Model3
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


class PositionalEmbedding(layers.Layer):
    """Sums learned token embeddings with learned position embeddings."""

    def __init__(self, sequence_length, vocab_size, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.token_embeddings = layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)
        self.position_embeddings = layers.Embedding(input_dim=sequence_length, output_dim=embed_dim)
        self.sequence_length = sequence_length
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim

    def call(self, inputs):
        # Positions 0..length-1 are embedded and broadcast across the batch.
        length = tf.shape(inputs)[-1]
        positions = tf.range(start=0, limit=length, delta=1)
        embedded_tokens = self.token_embeddings(inputs)
        embedded_positions = self.position_embeddings(positions)
        return embedded_tokens + embedded_positions
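As a quick sanity check (a sketch, not part of the original gist; the hyperparameter values are illustrative only), the layer maps an integer batch of shape (batch, seq_len) to embeddings of shape (batch, seq_len, embed_dim):

# Illustrative shape check, assuming vocab_size=20000, sequence_length=128, embed_dim=256.
emb = PositionalEmbedding(sequence_length=128, vocab_size=20000, embed_dim=256)
demo_tokens = tf.random.uniform((2, 16), maxval=20000, dtype=tf.int64)
print(emb(demo_tokens).shape)  # (2, 16, 256)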
class DecoderBlock(layers.Layer):
    """Self-attention sub-layer followed by a two-layer feed-forward network,
    each wrapped in a residual connection and layer normalization."""

    def __init__(self, embed_dim, num_heads, dropout_rate, **kwargs):
        super().__init__(**kwargs)
        self.attention = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.dropout1 = layers.Dropout(dropout_rate)
        self.norm1 = layers.LayerNormalization()
        self.dense1 = layers.Dense(embed_dim, activation="relu")
        self.dense2 = layers.Dense(embed_dim)
        self.dropout2 = layers.Dropout(dropout_rate)
        self.norm2 = layers.LayerNormalization()

    def call(self, inputs, training=False):
        # Self-attention: query, key, and value all come from the same inputs.
        # Note that no causal mask is applied, so attention is bidirectional.
        attn_output = self.attention(inputs, inputs)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.norm1(inputs + attn_output)
        # Position-wise feed-forward sub-layer.
        dense_output = self.dense1(out1)
        dense_output = self.dense2(dense_output)
        dense_output = self.dropout2(dense_output, training=training)
        return self.norm2(out1 + dense_output)
class TransformerDecoder(layers.Layer):
    """Stacks num_blocks DecoderBlock layers."""

    def __init__(self, embed_dim, num_heads, num_blocks, dropout_rate, **kwargs):
        super().__init__(**kwargs)
        self.num_blocks = num_blocks
        self.blocks = [DecoderBlock(embed_dim, num_heads, dropout_rate) for _ in range(num_blocks)]

    def call(self, inputs, training=False):
        x = inputs
        for block in self.blocks:
            x = block(x, training=training)
        return x
def transformer_model(vocab_size, sequence_length, embed_dim, num_heads, num_blocks, dropout_rate, num_classes):
    inputs = keras.Input(shape=(None,), dtype="int64")
    x = PositionalEmbedding(sequence_length, vocab_size, embed_dim)(inputs)
    # Let Keras propagate the training flag instead of hard-coding training=True,
    # so dropout is disabled automatically at inference time.
    x = TransformerDecoder(embed_dim, num_heads, num_blocks, dropout_rate)(x)
    # Classify from the representation of the last time step.
    last_step_output = layers.Lambda(lambda x: x[:, -1, :])(x)
    # Alternative output layer for next-token prediction over the vocabulary:
    # outputs = layers.Dense(vocab_size, activation="softmax")(last_step_output)
    outputs = layers.Dense(num_classes, activation="softmax")(last_step_output)
    return keras.Model(inputs, outputs)
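Below is a minimal usage sketch, not part of the original gist: the hyperparameter values, optimizer, loss, and the train_ds/val_ds datasets are assumptions chosen only to show how the model might be built and compiled for a sequence-classification setup with integer labels.

# Hypothetical hyperparameters for illustration only.
model = transformer_model(
    vocab_size=20000,
    sequence_length=128,
    embed_dim=256,
    num_heads=4,
    num_blocks=2,
    dropout_rate=0.1,
    num_classes=10,
)
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
model.summary()
# model.fit(train_ds, validation_data=val_ds, epochs=10)  # train_ds/val_ds are placeholder datasets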