import torch
from torch import nn, chunk, einsum
from torch.nn.functional import softmax
from math import sqrt


class TinyAttention(nn.Module):
    """Single-head scaled dot-product attention with a small attention width d_attn."""

    def __init__(self, d_attn: int, d_ffn: int):
        super().__init__()
        # One linear layer produces queries, keys and values in a single pass.
        self.proj_qkv = nn.Linear(d_ffn, 3 * d_attn)
        # Project the attention output back to the model width.
        self.proj_ffn = nn.Linear(d_attn, d_ffn)

    def forward(self, x):
        # x: (batch, sequence, d_ffn)
        q, k, v = chunk(self.proj_qkv(x), 3, dim=-1)
        # Pairwise query/key scores: (batch, n, m)
        w = einsum("bnd,bmd->bnm", q, k)
        # Scaled softmax over the key dimension.
        a = softmax(w / sqrt(q.size(-1)), dim=-1)
        # Weighted sum of the values: (batch, n, d_attn)
        x = einsum("bnm,bmd->bnd", a, v)
        return self.proj_ffn(x)


# Test
layer = TinyAttention(32, 64)
x = torch.randn((8, 12, 64))
x = layer(x)
x.shape  # torch.Size([8, 12, 64])
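
# Optional sanity check (a minimal sketch, assuming the TinyAttention layer defined
# above): recompute the attention weights the same way forward() does and confirm
# that each row of the softmax output sums to 1, i.e. every query position holds a
# proper probability distribution over the 12 key positions.
with torch.no_grad():
    inp = torch.randn((8, 12, 64))
    q, k, _ = chunk(layer.proj_qkv(inp), 3, dim=-1)
    w = einsum("bnd,bmd->bnm", q, k)
    a = softmax(w / sqrt(q.size(-1)), dim=-1)
a.shape          # torch.Size([8, 12, 12]): one weight per query/key pair
a.sum(dim=-1)    # all entries are 1 (up to floating-point error)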