Created May 17, 2021 20:59
Save sidneyarcidiacono/a7d78dd52f010decfa9c6e997e202602 to your computer and use it in GitHub Desktop.
Building our model with pytorch-geometric
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Import everything we need to build our network:
from typing import Optional

import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool
# Define our GCN class as a pytorch Module
class GCN(torch.nn.Module):
    """Graph-level classifier: three GCN layers, mean pooling, linear head.

    Node features pass through three ``GCNConv`` layers, with ReLU after the
    first two. The resulting node embeddings are averaged per graph with
    ``global_mean_pool``, dropout is applied, and a ``Linear`` layer maps each
    pooled embedding to one raw score per class. No softmax is applied; the
    scores are returned as-is.
    """

    def __init__(
        self,
        hidden_channels: int,
        num_node_features: Optional[int] = None,
        num_classes: Optional[int] = None,
        dropout: float = 0.5,
    ):
        """Create the convolution layers and the classifier head.

        Args:
            hidden_channels: Output width of every GCN layer and input width
                of the final linear classifier.
            num_node_features: Size of each input node feature vector. When
                omitted, ``dataset.num_node_features`` is read from a
                module-level ``dataset`` (the original behaviour), which must
                then exist.
            num_classes: Number of output scores per graph. When omitted,
                ``dataset.num_classes`` is used, as above.
            dropout: Dropout probability applied to the pooled graph
                embedding while the module is in training mode.
        """
        super().__init__()
        # Only consult the global ``dataset`` when sizes are not passed
        # explicitly, so existing ``GCN(hidden_channels=...)`` calls still work
        # while the model can also be built without that global.
        if num_node_features is None:
            num_node_features = dataset.num_node_features
        if num_classes is None:
            num_classes = dataset.num_classes
        self.dropout = dropout
        # Three message-passing layers; only the first changes the width
        # (input features -> hidden_channels).
        self.conv1 = GCNConv(num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        # Final linear layer produces one score per class.
        self.lin = Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index, batch):
        """Compute one row of class scores per graph in the mini-batch.

        Args:
            x: Node feature matrix. Presumably shaped
                ``[num_nodes, num_node_features]`` (TODO: confirm with caller).
            edge_index: Graph connectivity passed unchanged to every
                ``GCNConv`` layer.
            batch: Per-node graph assignment used by ``global_mean_pool`` to
                group nodes into graphs.

        Returns:
            Raw (un-normalised) class scores, one row per graph.
        """
        # 1. Obtain node embeddings. No ReLU after the last conv, so the
        #    pooled embedding keeps negative components.
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index).relu()
        x = self.conv3(x, edge_index)
        # 2. Readout layer: average node embeddings per graph.
        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]
        # 3. Apply a final classifier; dropout is active only in training mode.
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.lin(x)
# Build the network with 64 hidden units per GCN layer and print its
# module structure for a quick sanity check.
model = GCN(64)
print(model)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.