Skip to content

Instantly share code, notes, and snippets.

Last active November 18, 2019 04:38
Show Gist options
Save liuyijiang1994/4cb2391c6300e8147ee978ef3aef412b to your computer and use it in GitHub Desktop.
class GAT(nn.Module):
    """Dense multi-head Graph Attention Network classifier.

    ``nheads`` hidden attention heads run in parallel on the node features;
    their outputs are concatenated and fed to a single output attention
    layer that produces ``nclass`` scores per node, returned as
    log-probabilities.
    """

    def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads):
        """Dense version of GAT.

        Args:
            nfeat: Number of input features per node.
            nhid: Number of output features of each hidden attention head.
            nclass: Number of output classes.
            dropout: Dropout probability; applied to the inputs of both
                attention stages here and passed to every attention layer.
            alpha: Negative slope of the LeakyReLU used by the attention
                layers.
            nheads: Number of hidden attention heads.
        """
        super().__init__()
        self.dropout = dropout

        self.attentions = [
            GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True)
            for _ in range(nheads)
        ]
        # A plain Python list is not tracked by nn.Module, so each head is
        # registered explicitly. This also keeps the parameter names
        # ("attention_0.W", "attention_1.a", ...) stable for saved state_dicts.
        for i, attention in enumerate(self.attentions):
            self.add_module(f"attention_{i}", attention)

        # Head outputs are concatenated along dim=1, hence nhid * nheads inputs.
        self.out_att = GraphAttentionLayer(
            nhid * nheads, nclass, dropout=dropout, alpha=alpha, concat=False
        )

    def forward(self, x, adj):
        """Return per-node class log-probabilities.

        Args:
            x: 2-D node feature matrix, (N, nfeat).
            adj: Dense adjacency; entries > 0 mark the nodes each node may
                attend to.

        Returns:
            Tensor of log-probabilities normalized over dim=1.
        """
        # Pass self.training so dropout is disabled under model.eval().
        x = F.dropout(x, self.dropout, training=self.training)
        x = torch.cat([att(x, adj) for att in self.attentions], dim=1)
        x = F.dropout(x, self.dropout, training=self.training)
        # NOTE(review): ELU is applied to the output layer before
        # log_softmax; preserved from the original — confirm it is intended.
        x = F.elu(self.out_att(x, adj))
        return F.log_softmax(x, dim=1)
class GraphAttentionLayer(nn.Module):
    """Simple dense graph attention (GAT) layer.

    Similar to the layer of Velickovic et al., "Graph Attention Networks"
    (https://arxiv.org/abs/1710.10903). With node features ``h = input @ W``,
    the unnormalized score of the pair (i, j) is
    ``LeakyReLU(a^T [h_i || h_j])``. Scores are masked by ``adj``,
    softmax-normalized over j, and used to take a weighted sum of ``h``.
    """

    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        """Create the projection and attention parameters.

        Args:
            in_features: Size of each input node feature vector.
            out_features: Size of each output node feature vector.
            dropout: Dropout probability applied to the attention weights.
            alpha: Negative slope of the LeakyReLU applied to the scores.
            concat: If True the output goes through ELU (hidden heads whose
                outputs are concatenated); if False it is returned as-is.
        """
        super().__init__()
        self.dropout = dropout
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat

        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        # Attention vector over the concatenation [h_i || h_j].
        self.a = nn.Parameter(torch.zeros(size=(2 * out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, input, adj):
        """Aggregate node features with masked self-attention.

        Args:
            input: 2-D node feature matrix of shape (N, in_features)
                (required by ``torch.mm``).
            adj: Dense adjacency broadcastable to (N, N); entries > 0 mark
                the nodes j that node i may attend to.

        Returns:
            (N, out_features) tensor; ELU-activated when ``concat`` is True.
        """
        h = torch.mm(input, self.W)  # (N, out_features)

        # a^T [h_i || h_j] == a[:F]^T h_i + a[F:]^T h_j, so the (N, N) score
        # matrix is a broadcast sum of two (N, 1) projections. This avoids
        # materializing the (N, N, 2F) tensor of all concatenated pairs.
        f = self.out_features
        src = torch.matmul(h, self.a[:f, :])  # (N, 1): term for node i
        dst = torch.matmul(h, self.a[f:, :])  # (N, 1): term for node j
        e = self.leakyrelu(src + dst.t())  # (N, N)

        # Masked scores get a huge negative value so exp() underflows to 0
        # and they receive no weight. Because the fill is finite (not -inf),
        # a row with no edges yields uniform weights instead of NaN.
        # NOTE(review): -9e15 is not representable in float16 (becomes -inf).
        masked_fill = torch.full_like(e, -9e15)
        attention = torch.where(adj > 0, e, masked_fill)
        attention = F.softmax(attention, dim=1)
        # Pass self.training so attention dropout is disabled under eval().
        attention = F.dropout(attention, self.dropout, training=self.training)
        h_prime = torch.matmul(attention, h)

        return F.elu(h_prime) if self.concat else h_prime

    def __repr__(self):
        return f"{self.__class__.__name__} ({self.in_features} -> {self.out_features})"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment