Skip to content

Instantly share code, notes, and snippets.

@Diyago
Last active November 5, 2020 18:28
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Diyago/35d208cb62af40a729c3db04b5c077dc to your computer and use it in GitHub Desktop.
class SimpleGNN(torch.nn.Module):
    """GCN graph classifier with JumpingKnowledge ("cat") aggregation.

    Stacks ``layers`` GCNConv layers, concatenates every layer's node
    embeddings via JumpingKnowledge, sum-pools node embeddings per graph
    with ``global_add_pool``, and classifies with a small MLP head.

    Original from
    http://pages.di.unipi.it/citraro/files/slides/Landolfi_tutorial.pdf

    Args:
        dataset: a torch_geometric-style dataset; must expose
            ``num_node_features`` and ``num_classes`` and support index
            selection (``dataset[index]``) in ``forward``.
        hidden: width of every hidden layer.
        layers: total number of GCNConv layers (assumed >= 1).
    """

    def __init__(self, dataset, hidden=64, layers=6):
        super().__init__()
        self.dataset = dataset
        self.convs = torch.nn.ModuleList()
        # First conv maps raw node features to the hidden width.
        self.convs.append(GCNConv(in_channels=dataset.num_node_features,
                                  out_channels=hidden))
        # Remaining (layers - 1) convs are hidden -> hidden.
        for _ in range(1, layers):
            self.convs.append(GCNConv(in_channels=hidden, out_channels=hidden))
        # "cat" mode concatenates all per-layer embeddings, hence the
        # hidden * layers input width of the projection below.
        self.jk = JumpingKnowledge(mode="cat")
        self.jk_lin = torch.nn.Linear(
            in_features=hidden * layers, out_features=hidden)
        self.lin_1 = torch.nn.Linear(in_features=hidden, out_features=hidden)
        self.lin_2 = torch.nn.Linear(
            in_features=hidden, out_features=dataset.num_classes)

    def forward(self, index):
        """Classify the graphs selected by ``index``.

        Args:
            index: indices into ``self.dataset`` selecting the graphs to
                batch together.

        Returns:
            Tensor of shape ``(num_graphs, num_classes)`` holding class
            probabilities (softmax over the last dimension).
        """
        data = Batch.from_data_list(self.dataset[index])
        x = data.x
        xs = []
        for conv in self.convs:
            x = F.relu(conv(x=x, edge_index=data.edge_index))
            xs.append(x)  # keep every layer's embeddings for JumpingKnowledge
        x = self.jk(xs)
        x = F.relu(self.jk_lin(x))
        x = global_add_pool(x, batch=data.batch)
        x = F.relu(self.lin_1(x))
        # NOTE(review): softmax output pairs with NLLLoss only after a log;
        # if training uses CrossEntropyLoss it expects raw logits instead —
        # confirm against the training loop before changing.
        x = F.softmax(self.lin_2(x), dim=-1)
        return x
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment