Gist by @janhuenermann · Last active June 28, 2021
import torch
from torch import nn, chunk, einsum
from torch.nn.functional import softmax
from math import sqrt


class TinyAttention(nn.Module):
    """Single-head scaled dot-product attention with a small head dimension,
    projecting from and back to the FFN dimension."""

    def __init__(self, d_attn: int, d_ffn: int):
        super().__init__()
        # One linear layer produces queries, keys and values at once
        self.proj_qkv = nn.Linear(d_ffn, 3 * d_attn)
        # Project the attention output back to the FFN dimension
        self.proj_ffn = nn.Linear(d_attn, d_ffn)

    def forward(self, x):
        # x: (batch, seq_len, d_ffn)
        q, k, v = chunk(self.proj_qkv(x), 3, dim=-1)
        # Pairwise dot-product scores: (batch, n, m)
        w = einsum("bnd,bmd->bnm", q, k)
        # Scale by sqrt(d_attn) and normalize over the key dimension
        a = softmax(w / sqrt(q.size(-1)), dim=-1)
        # Attention-weighted sum of values: (batch, n, d_attn)
        x = einsum("bnm,bmd->bnd", a, v)
        return self.proj_ffn(x)


# Test
layer = TinyAttention(32, 64)
x = torch.randn((8, 12, 64))
x = layer(x)
x.shape  # torch.Size([8, 12, 64])
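

# A minimal sketch, not part of the original gist: TinyAttention is cheap enough
# to run as a parallel branch next to a feed-forward block. The layer sizes and
# residual wiring below are illustrative assumptions, not the author's design.
class BlockWithTinyAttention(nn.Module):
    def __init__(self, d_ffn: int, d_attn: int = 32):
        super().__init__()
        self.norm = nn.LayerNorm(d_ffn)
        self.ffn = nn.Sequential(
            nn.Linear(d_ffn, d_ffn), nn.GELU(), nn.Linear(d_ffn, d_ffn)
        )
        self.attn = TinyAttention(d_attn, d_ffn)

    def forward(self, x):
        h = self.norm(x)
        # Residual connection plus the two parallel branches
        return x + self.ffn(h) + self.attn(h)


# Example: same input/output shape as before
block = BlockWithTinyAttention(64)
block(torch.randn((8, 12, 64))).shape  # torch.Size([8, 12, 64])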