@goddoe
Last active September 15, 2019 11:03
SelfAttention
import math

import torch
import torch.nn as nn


class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention."""

    def __init__(self, input_dim, output_dim, dropout=0.1):
        super(SelfAttention, self).__init__()
        # Linear projections for queries, keys, and values.
        self.q = nn.Linear(input_dim, output_dim)
        self.k = nn.Linear(input_dim, output_dim)
        self.v = nn.Linear(input_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: B x ... x S x D
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)
        # Scaled dot-product attention weights:
        # (B x ... x S x D) @ (B x ... x D x S) => (B x ... x S x S)
        alpha = self.dropout(torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(k.size(-1)), dim=-1))
        # Weighted sum of values:
        # (B x ... x S x S) @ (B x ... x S x D) => (B x ... x S x D)
        return alpha @ v
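
A minimal usage sketch (not part of the original gist): the batch size, sequence length, and feature dimensions below are illustrative assumptions, chosen only to show the expected tensor shapes.

attn = SelfAttention(input_dim=64, output_dim=64, dropout=0.1)
x = torch.randn(2, 10, 64)   # assumed shape: batch 2, sequence length 10, input_dim 64
out = attn(x)                # leading dims are preserved; last dim equals output_dim
print(out.shape)             # torch.Size([2, 10, 64])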