@morrisalp
Last active February 2, 2021 09:12
Minimal TF 2.0 (+ Keras) example of a transformer block, based on Peter Bloem's article "Transformers from Scratch" (http://www.peterbloem.nl/blog/transformers).
from tensorflow.keras.layers import Input, Dense, Lambda, Reshape, Activation, Layer, LayerNormalization, Add
from tensorflow.keras.models import Sequential
from tensorflow.keras import Model
import tensorflow as tf
class SelfAttention(Layer):
    def __init__(self, heads = 8):
        super().__init__()
        self.heads = heads

    def build(self, input_shape):
        # expects input of shape (b, t, k)
        # [b: batch dimension, t: time step, k: embedding dimension]
        super().build(input_shape)
        _, t, k = input_shape
        R = Reshape((t, self.heads, k))
        # scale queries and keys by k^(1/4) each, giving the usual 1/sqrt(k) factor overall
        L = Lambda(lambda x: x / (k ** 0.25))
        self.to_keys = Sequential([Dense(k * self.heads, use_bias = False), R, L])
        self.to_queries = Sequential([Dense(k * self.heads, use_bias = False), R, L])
        self.to_values = Sequential([Dense(k * self.heads, use_bias = False), R])
        self.softmax = Activation('softmax')
        self.unify_heads = Sequential([Reshape((t, self.heads * k)), Dense(k)])

    def call(self, x):
        K, Q, V = self.to_keys(x), self.to_queries(x), self.to_values(x)
        # attention weights: contract queries and keys over the embedding axis,
        # then softmax over the key positions T (the last output axis)
        A = self.softmax(tf.einsum('bthk,bThk->bthT', Q, K))
        # weighted sum of values over the key positions T
        R = tf.einsum('bthT,bThk->bthk', A, V)
        return self.unify_heads(R)
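# a quick shape sanity check (a sketch, not part of the original gist; the
# batch/time/embedding sizes below are arbitrary): self-attention should
# return a tensor of the same shape as its input, since unify_heads
# projects the concatenated heads back down to k
x = tf.random.normal((2, 5, 16))
print(SelfAttention(heads = 4)(x).shape)  # (2, 5, 16)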
class TransformerBlock(Layer):
    def __init__(self, heads = 8, ff_hidden_mult = 4):
        super().__init__()
        self.heads = heads
        self.ff_hidden_mult = ff_hidden_mult

    def build(self, input_shape):
        # expects input of shape (b, t, k)
        # [b: batch dimension, t: time step, k: embedding dimension]
        super().build(input_shape)
        _, t, k = input_shape
        self.hidden_dim = self.ff_hidden_mult * k
        self.attention = SelfAttention(heads = self.heads)
        # residual connection followed by layer normalization, used twice per block
        self.res_norm1 = Sequential([Add(), LayerNormalization()])
        self.res_norm2 = Sequential([Add(), LayerNormalization()])
        # position-wise feedforward network with a widened hidden layer
        self.feedforward = Sequential([Dense(self.hidden_dim, activation = 'relu'), Dense(k)])

    def call(self, x):
        A = self.attention(x)
        R = self.res_norm1([A, x])     # residual + layer norm around attention
        F = self.feedforward(R)
        return self.res_norm2([F, R])  # residual + layer norm around feedforward
# example of TransformerBlock layer in model:
seq_length = 25
embedding_dim = 10
inputs = Input(shape = (seq_length, embedding_dim))
tb = TransformerBlock()(inputs)
model = Model(inputs = inputs, outputs = tb)
model.summary()
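# a minimal training sketch on random data (hypothetical regression task,
# not part of the original gist), just to show the block is trainable
# end to end:
import numpy as np
X = np.random.randn(32, seq_length, embedding_dim).astype('float32')
y = np.random.randn(32, seq_length, embedding_dim).astype('float32')
model.compile(optimizer = 'adam', loss = 'mse')
model.fit(X, y, epochs = 2, batch_size = 8)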