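"""Compare two RoPE (rotary position embedding) implementations.

Gist by @andreaskoepf, created August 13, 2023.

* An interleaved, complex-valued variant (precompute_freqs_cis / apply_rotary_emb)
  that supports explicit, possibly non-monotonic position_ids and linear position
  scaling.
* The half-split ("sliced") variant from Hugging Face transformers' modeling_llama
  (LlamaLinearScalingRotaryEmbedding / apply_rotary_pos_emb).

main() applies both to identical inputs and prints the difference after converting
the interleaved output into the half-split layout.
"""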
from typing import Optional

import torch


def precompute_freqs_cis(
    dim: int, end: int, theta: float = 10000.0, scaling_factor: float = 1.0
) -> torch.Tensor:
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(end, device=freqs.device).float() / scaling_factor  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    return torch.polar(torch.ones_like(freqs), freqs)  # complex64

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[0], x.shape[-1])
    shape = [d if i == 0 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)
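
# Shape sketch for reshape_for_broadcast (variable names here are illustrative):
# for a complex-viewed query of shape [seq_len, batch_size, heads, head_dim // 2],
# freqs_cis of shape [seq_len, head_dim // 2] is viewed as
# [seq_len, 1, 1, head_dim // 2] so it broadcasts over batch and heads, e.g.:
#
#     _xq = torch.zeros(10, 3, 1, 16, dtype=torch.complex64)
#     _fc = precompute_freqs_cis(dim=32, end=10)          # [10, 16] complex64
#     assert reshape_for_broadcast(_fc, _xq).shape == (10, 1, 1, 16)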

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
    position_ids: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    if position_ids is None:
        # we assume position_ids to be torch.arange(0, seq_len)
        freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
        # freqs_cis: [seq_len, 1, 1, head_dim//2] (complex64)
    else:
        # use specified position_ids, possibly not monotonically increasing
        # tensor shapes & types:
        # xq_: [seq_len, batch_size, heads, head_dim//2] (complex64)
        # position_ids: [batch_size, seq_len] (long)
        assert position_ids.shape == (xq_.shape[1], xq_.shape[0])
        assert freqs_cis.shape[1] == xq_.shape[-1]
        freqs_cis = freqs_cis[position_ids].transpose(0, 1).unsqueeze(-2)
        # freqs_cis: [seq_len, batch_size, 1, head_dim//2] (complex64)
    freqs_cis = freqs_cis.to(xq.device)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
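
# The interleaved ("MT"-style) path above treats consecutive feature pairs
# (x[..., 2*j], x[..., 2*j + 1]) as one complex number and multiplies it by
# exp(1j * m * theta_j) for position m. A minimal usage sketch for the explicit,
# non-monotonic position_ids branch (names are illustrative, not from the gist):
#
#     _q = torch.ones(10, 3, 1, 32)                  # [seq, batch, heads, head_dim]
#     _k = torch.ones(10, 3, 1, 32)
#     _fc = precompute_freqs_cis(dim=32, end=10)     # [10, 16] complex64
#     _pos = torch.arange(10).flip(0).repeat(3, 1)   # [batch, seq], reversed positions
#     _q_rot, _k_rot = apply_rotary_emb(_q, _k, _fc, _pos)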

# from modeling_llama of transformers
class LlamaRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )

class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        t = t / self.scaling_factor

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

# transformers code
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
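
# In this "sliced" (HF) layout the rotated pair is (q[..., j], q[..., j + d//2]):
#     q_embed[..., j]        = q[..., j] * cos(m * theta_j) - q[..., j + d//2] * sin(m * theta_j)
#     q_embed[..., j + d//2] = q[..., j] * sin(m * theta_j) + q[..., j + d//2] * cos(m * theta_j)
# which is the same rotation the complex path applies to the interleaved pair
# (q[..., 2*j], q[..., 2*j + 1]); main() below permutes one layout into the other
# before comparing.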

def apply_rot_transformers(q, k, v: torch.Tensor, position_ids: torch.Tensor, scaling_factor: float = 1.0):
    batch_size, heads, seq_len, head_dim = v.shape
    s1 = LlamaLinearScalingRotaryEmbedding(head_dim, scaling_factor=scaling_factor, base=10000)
    cos, sin = s1(v, seq_len=seq_len)
    return apply_rotary_pos_emb(q, k, cos, sin, position_ids)

def apply_rot_mt(q, k: torch.Tensor, position_ids: torch.Tensor, scaling_factor: float = 1.0):
    seq_len, batch_size, heads, head_dim = q.shape
    freqs = precompute_freqs_cis(dim=head_dim, end=seq_len, theta=10000.0, scaling_factor=scaling_factor)
    print('freqs', freqs.shape, freqs.dtype)
    print('q', q.shape)
    print('k', k.shape)
    return apply_rotary_emb(q, k, freqs, position_ids)

def main():
    batch_size = 3
    heads = 1
    head_dim = 32
    seq_len = 10
    scaling_factor = 1.0

    v = torch.zeros(batch_size, heads, seq_len, head_dim)
    q = torch.ones(batch_size, heads, seq_len, head_dim)
    k = torch.ones(batch_size, heads, seq_len, head_dim)

    position_ids = torch.arange(seq_len)  # .flip(0)
    position_ids = position_ids.repeat(batch_size, 1)
    print('position_ids', position_ids)

    q1, k1 = apply_rot_transformers(q, k, v, position_ids, scaling_factor=scaling_factor)

    q = torch.ones(seq_len, batch_size, heads, head_dim)
    k = torch.ones(seq_len, batch_size, heads, head_dim)
    q2, k2 = apply_rot_mt(q, k, position_ids, scaling_factor=scaling_factor)

    # convert MT -> sliced rotary HF format
    q2 = q2.permute(1, 2, 0, 3)
    q3 = torch.zeros_like(q2)
    q3[..., :head_dim // 2] = q2[..., 0::2]
    q3[..., head_dim // 2:] = q2[..., 1::2]

    print('diff:', (q1 - q3).abs().sum().item())


if __name__ == "__main__":
    main()
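
# Running this script prints the tensor shapes and the summed absolute difference
# between the two query outputs after the interleaved (MT) result has been permuted
# into the half-split HF layout; if the two paths agree, the reported difference
# should be zero up to floating-point tolerance.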