Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Two-dimensional positional encoding in PyTorch (inspired by https://arxiv.org/abs/1706.03762)
import math
from typing import Tuple, Optional

import torch
@torch.jit.script
def positional_encoding_2d(shape: Tuple[int, int, int], temperature: float = 1e4, scale: float = 2*math.pi,
                           dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None):
    """Return a two-dimensional sinusoidal positional encoding of shape ``[d_model, h, w]``.

    Channels come in groups of four. Group ``k`` holds
    ``sin(x / f_k), sin(y / f_k), cos(x / f_k), cos(y / f_k)`` with
    ``f_k = temperature ** (4 * k / d_model)``. The coordinate ``x`` runs
    linearly from 0 to ``scale`` across the width, and ``y`` does the same
    across the height.

    Args:
        shape: ``(d_model, h, w)`` of the encoding to build.
        temperature: base of the geometric frequency progression.
        scale: coordinate value assigned to the last row / column.
        dtype: dtype passed to ``torch.arange``; with ``None`` the integer
            ranges are promoted to the default float dtype by the divisions.
        device: device on which the encoding is created.

    Returns:
        Tensor of shape ``[d_model, h, w]``. When ``d_model`` is not a
        multiple of 4, the trailing ``d_model % 4`` channels are zero.
    """
    d_model, h, w = shape
    n_freq = d_model // 4
    i = torch.arange(n_freq, dtype=dtype, device=device)
    # Normalise by max(n - 1, 1): a single row/column maps to 0 instead of 0/0 = NaN.
    ys = torch.arange(h, dtype=dtype, device=device) / max(h - 1, 1) * scale
    xs = torch.arange(w, dtype=dtype, device=device) / max(w - 1, 1) * scale
    freqs = temperature ** (4. / d_model * i)  # [n_freq]
    # [2, h, w]: index 0 is the x coordinate grid, index 1 the y coordinate grid.
    pos = torch.stack([xs.expand(h, w), ys.unsqueeze(-1).expand(h, w)], 0)
    angles = pos.unsqueeze(0) / freqs.view(-1, 1, 1, 1)  # [n_freq, 2 (x/y), h, w]
    # Stack to [n_freq, 2 (sin/cos), 2 (x/y), h, w] and flatten the first three dims.
    # The resulting channel order is: sin(x0) sin(y0) cos(x0) cos(y0) sin(x1) ...
    enc = torch.stack([angles.sin(), angles.cos()], 1).reshape([4 * n_freq, h, w])
    # Honour the [d_model, h, w] contract when d_model % 4 != 0.
    pad = d_model - 4 * n_freq
    if pad > 0:
        enc = torch.cat([enc, enc.new_zeros([pad, h, w])], 0)
    return enc
@torch.jit.script
def positional_encoding_2d_as(x: torch.Tensor, temperature: float = 1e4, scale: float = 2*math.pi):
    """Build a 2-D positional encoding matching ``x`` and broadcast it to ``x``'s shape.

    The last three dimensions of ``x`` are used as ``(d, h, w)`` for
    ``positional_encoding_2d``. The encoding is created with ``x``'s dtype
    and device, then broadcast with ``expand_as(x)``. Broadcasting fails if
    the encoding's channel count differs from ``x.shape[-3]``.
    """
    dims = x.size()
    channels = dims[-3]
    height = dims[-2]
    width = dims[-1]
    encoding = positional_encoding_2d((channels, height, width), temperature, scale, x.dtype, x.device)
    return encoding.expand_as(x)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment