MultiHeadAttention block of a Graph Attention Network (https://arxiv.org/abs/1710.10903)
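For reference, each head of the module below computes the masked attention of the paper above, with one small twist: pairs are scored with the projection W but aggregated with a second projection Q. A restatement under that reading (W, Q, a correspond to the module's parameters, \mathcal{N}_i is the set of nodes j with mask[i, j] == False, and H is the number of heads):

    e_{ij}^{(m)} = \mathrm{LeakyReLU}\!\left(\mathbf{a}^{(m)\top}\left[\mathbf{W}^{(m)} h_i \,\Vert\, \mathbf{W}^{(m)} h_j\right]\right),
    \qquad
    \alpha_{ij}^{(m)} = \frac{\exp\!\big(e_{ij}^{(m)}\big)}{\sum_{k \in \mathcal{N}_i} \exp\!\big(e_{ik}^{(m)}\big)},
    \qquad
    h_i' = \Big\Vert_{m=1}^{H} \tanh\!\Big(\sum_{j \in \mathcal{N}_i} \alpha_{ij}^{(m)}\, \mathbf{Q}^{(m)} h_j\Big)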
import numpy as np
import torch
import torch.nn as nn


class GATMHA(nn.Module):
    """Multi-head attention block of a Graph Attention Network
    (https://arxiv.org/abs/1710.10903)."""

    def __init__(self, hidden_size, n_heads, lralpha=0.2):
        super(GATMHA, self).__init__()
        self.W = nn.Linear(hidden_size, hidden_size, bias=True)
        self.Q = nn.Linear(hidden_size, hidden_size, bias=True)
        # Per-head attention vector a, applied to the concatenation [Wh_i || Wh_j].
        self.a = nn.Parameter(torch.FloatTensor(n_heads, 2 * (hidden_size // n_heads)))
        self.lralpha = lralpha
        self.n_heads = n_heads
        self.n_hidden_per_head = hidden_size // n_heads
        nn.init.uniform_(self.W.weight, -1 / np.sqrt(hidden_size), 1 / np.sqrt(hidden_size))
        nn.init.uniform_(self.Q.weight, -1 / np.sqrt(hidden_size), 1 / np.sqrt(hidden_size))
        nn.init.uniform_(self.a.data, -1 / np.sqrt(hidden_size), 1 / np.sqrt(hidden_size))
        self.activation = nn.LeakyReLU(self.lralpha)

    def forward(self, x, mask):
        """
        Parameters
        ----------
        x : torch.FloatTensor, [batch_size, n_seq, hidden_size]
        mask : torch.BoolTensor, [batch_size, n_seq, n_seq]
            True marks pairs (i, j) that must not attend to each other.

        Returns
        -------
        y : torch.FloatTensor, [batch_size, n_seq, hidden_size]
        """
        bs, n_seq, hidden_size = x.shape
        # Project and split into heads: [bs, n_seq, d_head, n_heads], then tile so
        # every pair (i, j) can be scored: [bs, i, j, d_head, n_heads].
        w = self.W(x).reshape(bs, n_seq, self.n_hidden_per_head, self.n_heads)
        w = w.unsqueeze(2).repeat(1, 1, n_seq, 1, 1)
        q = self.Q(x).reshape(bs, n_seq, self.n_hidden_per_head, self.n_heads)
        # Concatenate [Wh_i || Wh_j] along the feature axis: [bs, i, j, 2*d_head, n_heads].
        catted = torch.cat([w, w.permute(0, 2, 1, 3, 4)], -2)
        # Attention logits per head and per pair: e_ij = a^T [Wh_i || Wh_j].
        logits = torch.einsum("bijkh,hk->bijh", catted, self.a)
        # Zero masked pairs *after* the exp so they receive exactly zero attention;
        # zeroing the logits instead would still leave them a weight of exp(0) = 1.
        weight = torch.exp(self.activation(logits)) * (1.0 - mask.float()).unsqueeze(-1)
        # Normalize over neighbours j: a softmax restricted to unmasked pairs.
        weight_sum = weight.sum(-2)
        final_weight = weight / (1e-12 + weight_sum.unsqueeze(-2))
        # Weighted aggregation per head: [bs, h, i, j] @ [bs, h, j, d_head] -> [bs, h, i, d_head].
        r = torch.matmul(final_weight.permute(0, 3, 1, 2), q.permute(0, 3, 1, 2))
        # Apply the nonlinearity and concatenate heads back to [bs, n_seq, hidden_size].
        return torch.tanh(r).permute(0, 2, 3, 1).reshape(bs, n_seq, -1)
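A minimal smoke test of the block above (a sketch; the sizes are illustrative, hidden_size is assumed divisible by n_heads, and mask is built so that True means "no edge"):

n_heads, hidden_size = 4, 64
bs, n_seq = 2, 5

layer = GATMHA(hidden_size, n_heads)
x = torch.randn(bs, n_seq, hidden_size)

# Random symmetric adjacency with self-loops; invert it so True marks masked pairs.
adj = torch.rand(bs, n_seq, n_seq) < 0.5
mask = ~(adj | adj.transpose(1, 2) | torch.eye(n_seq, dtype=torch.bool))

y = layer(x, mask)
print(y.shape)  # torch.Size([2, 5, 64])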