@lucidrains
Created June 3, 2020 17:42
import torch
from torch import nn
import torch.nn.functional as F

def psi(x):
    # positive feature map: elu(x) + 1 > 0, standing in for the softmax kernel
    return F.elu(x) + 1

class LinearAttention(nn.Module):
    def __init__(self, dim, heads):
        super().__init__()
        self.heads = heads
        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
        self.to_out = nn.Linear(dim, dim)

    def forward(self, x):
        b, t, d = x.shape
        h = self.heads
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)

        # split heads and fold them into the batch dimension: (b, t, d) -> (b * h, t, d // h)
        merge_heads = lambda x: x.reshape(b, t, h, -1).transpose(1, 2).reshape(b * h, t, -1)
        q, k, v = map(merge_heads, (q, k, v))

        # apply the positive feature map to queries and keys
        q, k = map(psi, (q, k))
        norm_q = q / q.sum(dim=-1, keepdim=True)

        # causal linear attention: running key-value outer-product sums over the
        # sequence, normalized by the running key sums
        context = torch.einsum('bnd,bne->bnde', k, v)
        context = context.cumsum(dim=1) / k.cumsum(dim=1).unsqueeze(-1)
        out = torch.einsum('bnd,bnde->bne', norm_q, context)

        # merge heads back: (b * h, t, d // h) -> (b, t, d)
        out = out.reshape(b, h, t, -1).transpose(1, 2).reshape(b, t, -1)
        return self.to_out(out)

x = torch.randn(1, 1024, 512)
attn = LinearAttention(dim=512, heads=8)
attn(x)
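
A quick sanity check, not part of the original gist: since the module only mixes information along the sequence, its output should have the same (batch, sequence, dim) shape as the input.

out = attn(x)
assert out.shape == (1, 1024, 512)  # matches the shape of x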