@shivance
Last active June 6, 2023 19:59
Rotary Embedding

A Keras/TensorFlow layer implementing rotary position embeddings (RoPE, Su et al., 2021): query and key features are rotated in pairs by position-dependent angles, so that the attention scores between them depend on relative position.
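For reference, a sketch of the per-pair rotation the code below implements, in the "rotate half" layout that pairs feature i with feature i + d/2 (m is the token position, d the size of the last axis):

\begin{pmatrix} y_i \\ y_{i + d/2} \end{pmatrix}
= \begin{pmatrix} \cos m\theta_i & -\sin m\theta_i \\ \sin m\theta_i & \cos m\theta_i \end{pmatrix}
\begin{pmatrix} x_i \\ x_{i + d/2} \end{pmatrix},
\qquad \theta_i = 10000^{-2i/d}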
import tensorflow as tf
from tensorflow import keras


class RotaryEmbedding(keras.layers.Layer):
    """Rotary position embeddings (RoPE) for query and key tensors.

    `hidden_dim` must match the last axis of the inputs (the per-head
    size when inputs are shaped (batch, seq_len, num_heads, head_size)).
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim

    def build(self, input_shape):
        super().build(input_shape)
        # Fixed (non-trainable) inverse frequencies: 1 / 10000^(2i / hidden_dim).
        self.inverse_freq = self.add_weight(
            name="inverse_freq",
            shape=(self.hidden_dim // 2,),
            dtype=tf.float32,
            trainable=False,
        )
        self.inverse_freq.assign(
            1.0
            / (10000 ** (tf.range(start=0, limit=self.hidden_dim, delta=2, dtype=tf.float32) / self.hidden_dim))
        )

    def apply_rotary_pos_emb(self, tensor, cos_emb, sin_emb):
        # Truncate the cached angles to this tensor's sequence length.
        cos_emb = cos_emb[:, : tf.shape(tensor)[1], :, :]
        sin_emb = sin_emb[:, : tf.shape(tensor)[1], :, :]
        # "Rotate half": pair feature i with feature i + hidden_dim // 2.
        x1, x2 = tf.split(tensor, 2, axis=-1)
        half_rot_tensor = tf.concat((-x2, x1), axis=-1)
        return (tensor * cos_emb) + (half_rot_tensor * sin_emb)

    def _compute_cos_sin_embedding(self, x, seq_dim=1):
        seq_len = tf.shape(x)[seq_dim]
        # Angles are the outer product of positions and inverse frequencies.
        positions = tf.range(seq_len, dtype=self.inverse_freq.dtype)
        freqs = tf.einsum("i, j -> ij", positions, self.inverse_freq)
        # Duplicate along features; add broadcast axes for (batch, seq, heads, dim).
        embedding = tf.concat((freqs, freqs), axis=-1)[None, :, None, :]
        return tf.cos(embedding), tf.sin(embedding)

    def call(self, query, key):
        cos_emb, sin_emb = self._compute_cos_sin_embedding(key, seq_dim=1)
        q_emb = self.apply_rotary_pos_emb(query, cos_emb, sin_emb)
        k_emb = self.apply_rotary_pos_emb(key, cos_emb, sin_emb)
        return q_emb, k_emb
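A quick smoke test, as a minimal sketch: the shapes follow the keras-nlp note below, but the sizes themselves are arbitrary. Since RoPE is a pure rotation of feature pairs, it should leave per-head vector norms unchanged, which the last line checks.

rope = RotaryEmbedding(hidden_dim=8)  # hidden_dim must match the last axis of query/key
q = tf.random.normal((2, 16, 4, 8))  # (batch, seq_len, num_heads, head_size)
k = tf.random.normal((2, 16, 4, 8))
q_rot, k_rot = rope(q, k)
print(q_rot.shape, k_rot.shape)  # (2, 16, 4, 8) (2, 16, 4, 8)
# Rotations preserve norms, so this difference should be ~0.
print(tf.reduce_max(tf.abs(tf.norm(q, axis=-1) - tf.norm(q_rot, axis=-1))))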
@shivance commented Jun 6, 2023

Thanks @mattdangerw for helping fix this!

@shivance commented Jun 6, 2023

Note: in the keras-nlp design, the attention inputs have the following shapes (a sketch of wiring the layer in follows the snippet):

hidden_dim = 16
batch_size = 32
num_heads = 8
query_length = 256
key_length = 256
head_size = hidden_dim // num_heads

query = tf.ones((batch_size, query_length, num_heads, head_size))
key = tf.ones((batch_size, key_length, num_heads, head_size))
value = tf.ones((batch_size, key_length, num_heads, head_size))
attention_mask = tf.ones((batch_size, query_length, key_length))
hidden_states = tf.ones((batch_size, key_length, hidden_dim))
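
A hypothetical wiring of the layer above into this setup (not part of the gist): note that the layer takes the per-head size, head_size, as its hidden_dim, since that is the last axis of query and key.

rotary = RotaryEmbedding(hidden_dim=head_size)
query, key = rotary(query, key)
print(query.shape, key.shape)  # (32, 256, 8, 2) (32, 256, 8, 2)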
