Skip to content

Instantly share code, notes, and snippets.

@tchaton
Created December 6, 2020 19:16
Show Gist options
Save tchaton/52feddbb239b924ee1d1f1ae08a82cac to your computer and use it in GitHub Desktop.
import torch
from torch.nn import Sequential as Seq, Linear as Lin, ReLU
from torch_geometric.nn import MessagePassing
class EdgeConv(MessagePassing):
    """Edge convolution built on PyG's ``MessagePassing``.

    For every edge, an MLP is applied to the concatenation of the target
    node's features ``x_i`` and the relative features ``x_j - x_i``. The
    resulting per-edge messages are then aggregated per node with ``aggr``.

    Args:
        F_in: Number of input features per node.
        F_out: Number of output features per node (also the MLP's hidden
            width).
        aggr: Aggregation scheme forwarded to ``MessagePassing``
            (e.g. ``'add'``, ``'mean'``, ``'max'``). Defaults to ``'max'``,
            which matches the original hard-coded behaviour.
    """

    def __init__(self, F_in: int, F_out: int, aggr: str = 'max'):
        super().__init__(aggr=aggr)
        # The MLP consumes [x_i, x_j - x_i], hence 2 * F_in input features.
        self.mlp = Seq(Lin(2 * F_in, F_out), ReLU(), Lin(F_out, F_out))

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Run one round of edge convolution.

        Args:
            x: Node features, shape [N, F_in].
            edge_index: Graph connectivity, shape [2, E].

        Returns:
            Aggregated node features, shape [N, F_out].
        """
        # Passing ``x=x`` lets PyG derive ``x_i`` / ``x_j`` for ``message``
        # by indexing ``x`` with ``edge_index``.
        return self.propagate(edge_index, x=x)

    def message(self, x_i: torch.Tensor, x_j: torch.Tensor) -> torch.Tensor:
        """Compute one message per edge.

        Args:
            x_i: Features of each edge's target node, shape [E, F_in].
            x_j: Features of each edge's source node, shape [E, F_in].

        Returns:
            Per-edge messages, shape [E, F_out].
        """
        # Combine the node's own features with its offset to the neighbour.
        edge_features = torch.cat([x_i, x_j - x_i], dim=1)  # [E, 2 * F_in]
        return self.mlp(edge_features)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment