# Gist by @chnsh, created March 15, 2020.
import torch
import torch.nn.functional as F
from torch.nn import Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.nn.inits import glorot, zeros
from torch_geometric.utils import remove_self_loops, add_self_loops
from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_scatter import scatter_max, scatter_add


def softmax(src, index, num_nodes=None):
    r"""Computes a sparsely evaluated softmax over a batched value tensor.

    Given a value tensor :attr:`src`, this function first groups the values
    along the second (edge) dimension based on the indices specified in
    :attr:`index`, and then proceeds to compute the softmax individually for
    each group.

    Args:
        src (Tensor): The source tensor of shape
            :obj:`[batch_size, num_edges, ...]`.
        index (LongTensor): The indices of elements for applying the softmax.
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`)

    :rtype: :class:`Tensor`
    """
    num_nodes = maybe_num_nodes(index, num_nodes)

    # Subtract the per-group maximum for numerical stability, then normalize
    # each group of edges that share a target node.
    out = src - scatter_max(src, index, dim=1, dim_size=num_nodes)[0][:, index]
    out = out.exp()
    out = out / (
        scatter_add(out, index, dim=1, dim_size=num_nodes)[:, index] + 1e-16)

    return out
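
# A worked example of the grouping semantics (a sketch, not part of the
# original gist): with `src` of shape (batch, num_edges, heads) and `index`
# holding the target node of each edge, entries that share a target node are
# normalized together:
#
#     src = torch.ones(1, 3, 2)          # 1 graph, 3 edges, 2 heads
#     index = torch.tensor([0, 0, 1])    # edges 0 and 1 both point at node 0
#     softmax(src, index)
#     # -> tensor([[[0.5, 0.5], [0.5, 0.5], [1.0, 1.0]]])
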
class GATConv(MessagePassing):
    r"""The graph attentional operator from the `"Graph Attention Networks"
    <https://arxiv.org/abs/1710.10903>`_ paper

    .. math::
        \mathbf{x}^{\prime}_i = \alpha_{i,i}\mathbf{\Theta}\mathbf{x}_{i} +
        \sum_{j \in \mathcal{N}(i)} \alpha_{i,j}\mathbf{\Theta}\mathbf{x}_{j},

    where the attention coefficients :math:`\alpha_{i,j}` are computed as

    .. math::
        \alpha_{i,j} =
        \frac{
        \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top}
        [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_j]
        \right)\right)}
        {\sum_{k \in \mathcal{N}(i) \cup \{ i \}}
        \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top}
        [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_k]
        \right)\right)}.

    In contrast to the stock implementation, this variant operates on batched
    node features of shape :obj:`[batch_size, num_nodes, in_channels]`.

    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample.
        heads (int, optional): Number of multi-head-attentions.
            (default: :obj:`1`)
        concat (bool, optional): If set to :obj:`False`, the multi-head
            attentions are averaged instead of concatenated.
            (default: :obj:`True`)
        negative_slope (float, optional): LeakyReLU angle of the negative
            slope. (default: :obj:`0.2`)
        dropout (float, optional): Dropout probability of the normalized
            attention coefficients which exposes each node to a stochastically
            sampled neighborhood during training. (default: :obj:`0`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """

    def __init__(self, in_channels, out_channels, heads=1, concat=True,
                 negative_slope=0.2, dropout=0, bias=True, **kwargs):
        super().__init__(aggr='add', **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout

        self.weight = Parameter(
            torch.Tensor(in_channels, heads * out_channels))
        self.att = Parameter(torch.Tensor(1, heads, 2 * out_channels))

        if bias and concat:
            self.bias = Parameter(torch.Tensor(heads * out_channels))
        elif bias and not concat:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.weight)
        glorot(self.att)
        zeros(self.bias)

    def forward(self, x, edge_index, size=None):
        """"""
        if size is None and torch.is_tensor(x):
            edge_index, _ = remove_self_loops(edge_index)
            edge_index, _ = add_self_loops(edge_index,
                                           num_nodes=x.size(self.node_dim))

        # Linearly transform node features into heads * out_channels.
        if torch.is_tensor(x):
            x = torch.matmul(x, self.weight)
        else:
            x = (None if x[0] is None else torch.matmul(x[0], self.weight),
                 None if x[1] is None else torch.matmul(x[1], self.weight))

        return self.propagate(edge_index, size=size, x=x)

    def message(self, edge_index_i, x_i, x_j, size_i):
        """
        :param edge_index_i: shape (E,), the target node of each edge
        :param x_i: (b, E, heads * out_channels)
        :param x_j: (b, E, heads * out_channels)
        :param size_i: number of target nodes
        :return: (b, E, heads, out_channels)
        """
        # Reshape x_i and x_j to expose the head dimension.
        x_j = x_j.view(-1, edge_index_i.shape[0], self.heads, self.out_channels)
        x_i = x_i.view(-1, edge_index_i.shape[0], self.heads, self.out_channels)

        # Concatenate x_i and x_j on the last dimension and apply the shared
        # attention vector, giving one coefficient per edge and head.
        alpha = (torch.cat([x_i, x_j], dim=-1) * self.att).sum(dim=-1)
        alpha = F.leaky_relu(alpha, self.negative_slope)
        alpha = softmax(alpha, edge_index_i, size_i)

        # Sample attention coefficients stochastically.
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        return x_j * torch.unsqueeze(alpha, dim=-1)

    def update(self, aggr_out):
        b, num_nodes, num_heads, out_channels = aggr_out.size()

        # Apply the ELU activation as in the GAT paper, followed by dropout.
        aggr_out = F.elu(aggr_out)
        aggr_out = F.dropout(aggr_out, p=self.dropout, training=self.training)

        if self.concat:
            aggr_out = aggr_out.view(b, num_nodes, self.heads * self.out_channels)
        else:
            aggr_out = aggr_out.mean(dim=2)

        if self.bias is not None:
            aggr_out = aggr_out + self.bias

        return aggr_out

    def __repr__(self):
        return '{}({}, {}, heads={})'.format(self.__class__.__name__,
                                             self.in_channels,
                                             self.out_channels, self.heads)
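
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original gist). It assumes the
# batched convention suggested by the `message`/`update` docstrings above:
# node features of shape (batch_size, num_nodes, in_channels) with the node
# dimension at position 1, which requires forwarding `node_dim=1` to
# `MessagePassing` via **kwargs. Written against the torch_geometric API of
# early 2020; treat it as a sketch rather than a verified reference.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    conv = GATConv(in_channels=16, out_channels=8, heads=4, node_dim=1)

    x = torch.randn(2, 10, 16)                  # (batch, nodes, in_channels)
    edge_index = torch.randint(0, 10, (2, 40))  # one edge list shared by the batch

    out = conv(x, edge_index)
    print(out.shape)  # expected: torch.Size([2, 10, 32]) since concat=True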