@grey-area
Created November 22, 2022 16:35
import torch
import math
from torch import Tensor
from torch.nn.functional import dropout, softmax
from typing import Optional, Tuple

def get_relative_positional_encoding(length1: int, length2: int, d_model: int, device: torch.device) -> Tensor:
    # Relative offsets between every query position (rows) and key position (columns).
    xs = torch.arange(length1, device=device).unsqueeze(1)
    ys = torch.arange(length2, device=device).unsqueeze(0)
    position = ys - xs
    # Standard sinusoidal frequencies, as in the absolute positional encoding.
    div_term = torch.exp(torch.arange(0, d_model, 2, device=device) * (-math.log(10000.0) / d_model))
    angle = position.unsqueeze(-1) * div_term.view(1, 1, -1)
    # One encoding vector per (query position, key position) pair: (length1, length2, d_model).
    positional_encoding = torch.cat((torch.sin(angle), torch.cos(angle)), dim=-1)
    return positional_encoding / math.sqrt(d_model)

def scaled_dot_product_relative_attention(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    attn_mask: Optional[Tensor] = None,
    dropout_p: float = 0.0,
) -> Tuple[Tensor, Tensor]:
r"""
Computes scaled dot product attention on query, key and value tensors, using
an optional attention mask if passed, and applying dropout if a probability
greater than 0.0 is specified.
Returns a tensor pair containing attended values and attention weights.
Args:
q, k, v: query, key and value tensors. See Shape section for shape details.
attn_mask: optional tensor containing mask values to be added to calculated
attention. May be 2D or 3D; see Shape section for details.
dropout_p: dropout probability. If greater than 0.0, dropout is applied.
Shape:
- q: :math:`(B, Nt, E)` where B is batch size, Nt is the target sequence length,
and E is embedding dimension.
- key: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length,
and E is embedding dimension.
- value: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length,
and E is embedding dimension.
- attn_mask: either a 3D tensor of shape :math:`(B, Nt, Ns)` or a 2D tensor of
shape :math:`(Nt, Ns)`.
- Output: attention values have shape :math:`(B, Nt, E)`; attention weights
have shape :math:`(B, Nt, Ns)`
"""
    B, Nt, E = query.shape
    Ns = key.size(1)
    query = query / math.sqrt(E)
    # Relative positional encodings for every (target, source) position pair: (Nt, Ns, E).
    p = get_relative_positional_encoding(Nt, Ns, E, query.device)
    # Content-content scores plus content-position scores: (B, Nt, Ns).
    attn = torch.bmm(query, key.transpose(-2, -1)) + torch.einsum('blf,lmf->blm', query, p)
    if attn_mask is not None:
        attn = attn + attn_mask
    attn = softmax(attn, dim=-1)
    if dropout_p > 0.0:
        attn = dropout(attn, p=dropout_p)
    # Attend to the values, plus a position-based contribution: (B, Nt, E).
    output = torch.bmm(attn, value) + torch.einsum('blm,lmf->blf', attn, p)
    return output, attn
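

if __name__ == "__main__":
    # Usage sketch (not part of the original gist): the batch size, sequence lengths
    # and embedding dimension below are arbitrary; d_model must be even so that the
    # sin/cos halves of the encoding concatenate to the full dimension.
    B, Nt, Ns, E = 2, 5, 7, 16
    device = torch.device("cpu")

    query = torch.randn(B, Nt, E, device=device)
    key = torch.randn(B, Ns, E, device=device)
    value = torch.randn(B, Ns, E, device=device)

    # Optional additive mask: 0 where attention is allowed, -inf where it is blocked.
    attn_mask = torch.zeros(Nt, Ns, device=device)

    output, attn = scaled_dot_product_relative_attention(query, key, value, attn_mask=attn_mask)
    print(output.shape)  # torch.Size([2, 5, 16])
    print(attn.shape)    # torch.Size([2, 5, 7])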