Skip to content

Instantly share code, notes, and snippets.

@segyges
Created June 24, 2025 03:23
Show Gist options
  • Select an option

  • Save segyges/ced13b74ebc4b5c31a7aa55c80414bc8 to your computer and use it in GitHub Desktop.

Select an option

Save segyges/ced13b74ebc4b5c31a7aa55c80414bc8 to your computer and use it in GitHub Desktop.
# Originally by https://jerryxio.ng/
import torch
from torch import nn


class MultiPositionRotary(nn.Module):
    """Rotary position embedding driven by multi-dimensional positions.

    Each of the ``head_dim // 2`` rotation planes gets a random direction in
    ``pos_dim``-dimensional position space, scaled by a magnitude taken from a
    geometric ladder between ``min_freq`` and ``max_freq``. A plane's rotation
    angle is the mean over ``pos_dim`` of ``freqs * position``.

    Channels are paired half-split: feature ``i`` is rotated together with
    feature ``i + head_dim // 2``.

    NOTE(review): rows are RMS-normalized over ``pos_dim`` and the angle is a
    *mean* (not a sum) over ``pos_dim``, so the effective angular frequency
    along each direction is ``magnitude / sqrt(pos_dim)``. It equals
    ``magnitude`` exactly only when ``pos_dim == 1``. This is left unchanged
    so existing checkpoints keep their behaviour; confirm whether it is
    intended.
    """

    def __init__(
        self,
        head_dim: int,
        pos_dim: int,
        min_freq: float,
        max_freq: float,
        frozen: bool = True,
    ):
        """Build the random frequency table ``freqs_FP`` of shape (F, P).

        Args:
            head_dim: Per-head feature size D. Must be a positive even number,
                because the features are split into two halves that form the
                rotation pairs.
            pos_dim: Number of position coordinates P per token (>= 1).
            min_freq: Magnitude of the first frequency row; must be > 0.
            max_freq: Magnitude of the last frequency row; must be > 0.
            frozen: If True (default), ``freqs_FP`` is registered as a
                Parameter with ``requires_grad=False``; otherwise it is
                trainable.

        Raises:
            ValueError: If ``head_dim`` is not a positive even integer,
                ``pos_dim`` is not positive, or either frequency bound is not
                positive (these would otherwise yield NaN/inf frequencies or a
                shape error deep inside ``forward``).
        """
        if head_dim <= 0 or head_dim % 2 != 0:
            raise ValueError(
                f"head_dim must be a positive even integer, got {head_dim}"
            )
        if pos_dim <= 0:
            raise ValueError(f"pos_dim must be a positive integer, got {pos_dim}")
        if not (min_freq > 0 and max_freq > 0):  # also rejects NaN
            raise ValueError(
                "min_freq and max_freq must be positive, "
                f"got {min_freq} and {max_freq}"
            )
        super().__init__()
        self.head_dim = head_dim
        self.pos_dim = pos_dim
        nfreqs = head_dim // 2
        # Random Gaussian directions, rescaled so each row has unit RMS over
        # pos_dim; the 1e-7 keeps the division finite for a near-zero draw.
        freqs = torch.randn(nfreqs, pos_dim)
        freqs = freqs / (freqs.pow(2).mean(dim=1, keepdim=True) + 1e-7).sqrt()
        # Geometric ladder of magnitudes: min_freq for the first row up to
        # max_freq for the last row.
        magnitudes = min_freq * (max_freq / min_freq) ** torch.linspace(0, 1, nfreqs)
        freqs = freqs * magnitudes.unsqueeze(-1)
        self.freqs_FP = nn.Parameter(freqs, requires_grad=not frozen)

    def forward(self, x_NTHD: torch.Tensor, pos_NTP: torch.Tensor) -> torch.Tensor:
        """Rotate feature pairs of ``x_NTHD`` by position-dependent angles.

        Shape suffixes:
            N, T: leading dims (any leading dims work; pos_NTP's leading dims
                must broadcast against x_NTHD's leading dims)
            H: nheads
            P: pos_dim
            D: head_dim
            F: nfreqs == head_dim // 2

        Args:
            x_NTHD: Features of shape (..., H, D).
            pos_NTP: Positions of shape (..., P); integer or floating point.

        Returns:
            Tensor of the same dtype as ``x_NTHD``. Its shape is the broadcast
            of ``x_NTHD`` against the positions (normally ``x_NTHD.shape``).
        """
        assert x_NTHD.size(-1) == self.head_dim, (
            f"expected last dim of x to be {self.head_dim}, got {x_NTHD.size(-1)}"
        )
        assert pos_NTP.size(-1) == self.pos_dim, (
            f"expected last dim of pos to be {self.pos_dim}, got {pos_NTP.size(-1)}"
        )
        # Compute angles in at least float32. If the module was cast to
        # bf16/fp16 and positions are integers, freqs * pos would otherwise
        # be evaluated in half precision, which is far too coarse for
        # positions in the hundreds or thousands. float32/float64 inputs keep
        # their original dtype.
        theta_dtype = torch.promote_types(
            torch.result_type(self.freqs_FP, pos_NTP), torch.float32
        )
        freqs_FP = self.freqs_FP.to(theta_dtype)
        pos_NTP = pos_NTP.to(theta_dtype)
        theta_NTF = (freqs_FP * pos_NTP.unsqueeze(-2)).mean(dim=-1)
        cos_NTF = torch.cos(theta_NTF)
        sin_NTF = torch.sin(theta_NTF)
        # Rotate in float32 (matching the angle precision), then cast back.
        x_NTHF, y_NTHF = x_NTHD.float().chunk(2, dim=-1)
        # Insert a head axis so the per-token angles broadcast over heads.
        x_out_NTHF = x_NTHF * cos_NTF[..., None, :] - y_NTHF * sin_NTF[..., None, :]
        y_out_NTHF = x_NTHF * sin_NTF[..., None, :] + y_NTHF * cos_NTF[..., None, :]
        return torch.cat([x_out_NTHF, y_out_NTHF], dim=-1).type_as(x_NTHD)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment