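"""Sanity check (a rough sketch, not a benchmark): compare the Meta/LLaMA complex-valued
RoPE implementation (`precompute_freqs_cis` / `apply_rotary_emb`, seq-first layout,
interleaved real/imag pairs) against the Hugging Face transformers implementation
(`LlamaLinearScalingRotaryEmbedding` / `apply_rotary_pos_emb`, batch-first layout,
sliced halves), including linear position-id scaling. `main()` applies both to an
all-ones query/key tensor and prints the absolute difference after converting layouts."""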
from typing import Optional

import torch


def precompute_freqs_cis(
    dim: int, end: int, theta: float = 10000.0, scaling_factor: float = 1.0
) -> torch.Tensor:
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(end, device=freqs.device).float() / scaling_factor  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    return torch.polar(torch.ones_like(freqs), freqs)  # complex64


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[0], x.shape[-1])
    shape = [d if i == 0 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
    position_ids: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    if position_ids is None:
        # we assume position_ids to be torch.arange(0, seq_len)
        freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
        # freqs_cis: [seq_len, 1, 1, head_dim//2] (complex64)
    else:
        # use specified position_ids, possibly not monotonically increasing
        # tensor shapes & types:
        #   xq_: [seq_len, batch_size, heads, head_dim//2] (complex64)
        #   position_ids: [batch_size, seq_len] (long)
        assert position_ids.shape == (xq_.shape[1], xq_.shape[0])
        assert freqs_cis.shape[1] == xq_.shape[-1]
        freqs_cis = freqs_cis[position_ids].transpose(0, 1).unsqueeze(-2)
        # freqs_cis: [seq_len, batch_size, 1, head_dim//2] (complex64)
    freqs_cis = freqs_cis.to(xq.device)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


# from modeling_llama of transformers
class LlamaRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )


class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        t = t / self.scaling_factor

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)


# transformers code
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def apply_rot_transformers(q, k, v: torch.Tensor, position_ids: torch.Tensor, scaling_factor: float = 1.0):
    # HF path: inputs in [batch_size, heads, seq_len, head_dim] layout
    batch_size, heads, seq_len, head_dim = v.shape
    s1 = LlamaLinearScalingRotaryEmbedding(head_dim, scaling_factor=scaling_factor, base=10000)
    cos, sin = s1(v, seq_len=seq_len)
    return apply_rotary_pos_emb(q, k, cos, sin, position_ids)


def apply_rot_mt(q, k: torch.Tensor, position_ids: torch.Tensor, scaling_factor: float = 1.0):
    # Meta/LLaMA path: inputs in [seq_len, batch_size, heads, head_dim] layout
    seq_len, batch_size, heads, head_dim = q.shape
    freqs = precompute_freqs_cis(dim=head_dim, end=seq_len, theta=10000.0, scaling_factor=scaling_factor)
    print('freqs', freqs.shape, freqs.dtype)
    print('q', q.shape)
    print('k', k.shape)
    return apply_rotary_emb(q, k, freqs, position_ids)


def main():
    batch_size = 3
    heads = 1
    head_dim = 32
    seq_len = 10
    scaling_factor = 1.0

    # HF layout: [batch_size, heads, seq_len, head_dim]
    v = torch.zeros(batch_size, heads, seq_len, head_dim)
    q = torch.ones(batch_size, heads, seq_len, head_dim)
    k = torch.ones(batch_size, heads, seq_len, head_dim)

    position_ids = torch.arange(seq_len)  # .flip(0)
    position_ids = position_ids.repeat(batch_size, 1)
    print('position_ids', position_ids)

    q1, k1 = apply_rot_transformers(q, k, v, position_ids, scaling_factor=scaling_factor)

    # MT layout: [seq_len, batch_size, heads, head_dim]
    q = torch.ones(seq_len, batch_size, heads, head_dim)
    k = torch.ones(seq_len, batch_size, heads, head_dim)
    q2, k2 = apply_rot_mt(q, k, position_ids, scaling_factor=scaling_factor)

    # convert MT (interleaved real/imag pairs) -> sliced rotary HF format (first/second half)
    q2 = q2.permute(1, 2, 0, 3)
    q3 = torch.zeros_like(q2)
    q3[..., :head_dim // 2] = q2[..., 0::2]
    q3[..., head_dim // 2:] = q2[..., 1::2]
    print('diff:', (q1 - q3).abs().sum().item())


if __name__ == "__main__":
    main()