Skip to content

Instantly share code, notes, and snippets.

Created June 3, 2020 17:42
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
Star You must be signed in to star a gist
What would you like to do?
import torch
from torch import nn
import torch.nn.functional as F
def psi(x):
    """Positive feature map for linear attention: ``elu(x) + 1``.

    Maps any real tensor elementwise into the positive range so that the
    kernelized attention weights stay non-negative.
    """
    return 1 + F.elu(x)
class LinearAttention(nn.Module):
    """Multi-head causal (autoregressive) linear attention.

    Queries and keys are passed through the positive feature map
    ``elu(.) + 1`` and attention is accumulated with prefix sums over the
    sequence dimension, so each position only attends to itself and earlier
    positions.

    Fixes relative to the original snippet:
      * ``super().__init__()`` was never called, so assigning ``nn.Linear``
        submodules raised at construction time;
      * ``heads`` was never stored and ``h`` was an undefined name in
        ``forward``.

    Args:
        dim: model width; must be divisible by ``heads``.
        heads: number of attention heads.
    """

    def __init__(self, dim, heads):
        super().__init__()  # required before registering submodules
        self.heads = heads
        # One projection produces q, k, v in a single matmul.
        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.to_out = nn.Linear(dim, dim)

    def forward(self, x):
        """Apply causal linear attention.

        Args:
            x: tensor of shape ``(batch, seq_len, dim)``.
        Returns:
            Tensor of shape ``(batch, seq_len, dim)``.
        """
        b, t, d = x.shape
        h = self.heads
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        # (b, t, d) -> (b*h, t, d//h): fold heads into the batch dimension.
        merge_heads = lambda z: z.reshape(b, t, h, -1).transpose(1, 2).reshape(b * h, t, -1)
        q, k, v = map(merge_heads, (q, k, v))
        # Positive feature map (elu + 1) applied to queries and keys.
        q, k = F.elu(q) + 1, F.elu(k) + 1
        norm_q = q / q.sum(dim=-1, keepdim=True)
        # Per-position outer products k_i v_i^T, then causal prefix sums.
        context = torch.einsum('bnd,bne->bnde', k, v)
        # NOTE(review): normalizing by cumsum(k) per feature dimension differs
        # from the usual z_i = q_i . cumsum(k) denominator of Katharopoulos
        # et al. — preserved as written; confirm this is intended.
        context = context.cumsum(dim=1) / k.cumsum(dim=1).unsqueeze(-1)
        out = torch.einsum('bnd,bnde->bne', norm_q, context)
        # (b*h, t, e) -> (b, t, h*e): unfold heads back out of the batch.
        out = out.reshape(b, h, t, -1).transpose(1, 2).reshape(b, t, -1)
        return self.to_out(out)
# Smoke-test setup: an 8-head attention module over 512-dim features,
# plus a dummy batch of one 1024-token sequence.
attn = LinearAttention(dim=512, heads=8)
x = torch.randn(1, 1024, 512)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment