@arthurdouillard
Created May 30, 2021 15:37
import math

import torch
from torch import nn


class SelfAttention(nn.Module):
    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.Wq = nn.Linear(dim, num_heads * dim, bias=False)
        self.Wk = nn.Linear(dim, num_heads * dim, bias=False)
        self.Wv = nn.Linear(dim, num_heads * dim, bias=False)
        self.Wo = nn.Linear(num_heads * dim, dim)

    def forward(self, x):
        # x has shape (batch size, number of tokens, embedding dimension)
        B, T, D = x.shape
        q = self.Wq(x)  # (batch size, number of tokens, number of heads * embedding dimension)
        k = self.Wk(x)
        v = self.Wv(x)
        # Scaled dot-product attention weights: (batch size, number of tokens, number of tokens)
        a = torch.softmax(torch.bmm(q, k.transpose(1, 2)) / math.sqrt(D), dim=-1)
        z = torch.bmm(a, v)
        return self.Wo(z)  # (batch size, number of tokens, embedding dimension)
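
A minimal usage sketch of the module above; the tensor sizes (batch of 2, 10 tokens, embedding dimension 64) are illustrative assumptions, not part of the original gist:

# Usage sketch: dimensions chosen for illustration only.
attn = SelfAttention(dim=64, num_heads=8)
x = torch.randn(2, 10, 64)   # (batch size, number of tokens, embedding dimension)
out = attn(x)
print(out.shape)             # torch.Size([2, 10, 64])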