# Reading in the PROTEINS dataset
from spektral.datasets import TUDataset
# Spektral provides the TUDataset class, which contains benchmark datasets for graph classification
data = TUDataset('PROTEINS')
data
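Each entry of a graph-classification dataset is one whole graph: node features, an adjacency matrix, and a single graph-level label (in Spektral these are the `x`, `a`, and `y` attributes of a `Graph`). The NumPy sketch below builds one hypothetical toy graph in that layout; the sizes and values are illustrative only and are not taken from PROTEINS.

```python
import numpy as np

# Hypothetical toy graph: 4 nodes, 3 features per node, 2 possible classes
n_nodes, n_node_features, n_labels = 4, 3, 2

# x: one row of features per node, shape (n_nodes, n_node_features)
x = np.eye(n_nodes, n_node_features)

# a: symmetric 0/1 adjacency matrix for the path 0-1-2-3
a = np.zeros((n_nodes, n_nodes))
for i, j in [(0, 1), (1, 2), (2, 3)]:
    a[i, j] = a[j, i] = 1.0

# y: one-hot label for the whole graph (here: class 1 of 2)
y = np.array([0.0, 1.0])

print(x.shape, a.shape, y.shape)  # (4, 3) (4, 4) (2,)
```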
# Since we want to use Spektral's GCNConv layer, we follow the original GCN paper and preprocess
# each graph's adjacency matrix with GCNFilter, which adds self-loops and symmetrically normalizes it:
from spektral.transforms import GCNFilter
data.apply(GCNFilter())
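The preprocessing from the GCN paper replaces the adjacency matrix A with D^-1/2 (A + I) D^-1/2, where D is the degree matrix of A + I. The NumPy sketch below computes that normalization for a small path graph; it illustrates the idea and is not Spektral's own implementation.

```python
import numpy as np

def gcn_normalize(a):
    """Return D^-1/2 (A + I) D^-1/2, with D the degree matrix of A + I."""
    a_hat = a + np.eye(a.shape[0])          # add a self-loop to every node
    deg = a_hat.sum(axis=1)                 # degrees, counting the self-loop
    d_inv_sqrt = 1.0 / np.sqrt(deg)
    # Scale row i by d_i^-1/2 and column j by d_j^-1/2
    return d_inv_sqrt[:, None] * a_hat * d_inv_sqrt[None, :]

# Path graph 0-1-2: degrees with self-loops are 2, 3, 2
a = np.array([[0, 1, 0],
              [1, 0, 1],
              [0, 1, 0]], dtype=float)
a_norm = gcn_normalize(a)
print(np.round(a_norm, 3))
```

The symmetric scaling keeps the matrix symmetric and stops high-degree nodes from producing much larger aggregated features than low-degree nodes.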
# Split the data into train and test sets. Slicing takes the first 80% and the last 20% of the graphs,
# which can bias the split if the graphs are stored in a meaningful order, so we shuffle first.
import numpy as np
np.random.shuffle(data)
split = int(0.8 * len(data))
data_train, data_test = data[:split], data[split:]
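An alternative that leaves the dataset object untouched is to shuffle indices instead of the data and use a seeded generator so the split is reproducible. The helper name `train_test_split_indices` is hypothetical; the resulting index arrays could then be used to select graphs, assuming the dataset accepts integer-array indexing.

```python
import numpy as np

def train_test_split_indices(n_items, train_frac=0.8, seed=0):
    """Hypothetical helper: shuffle indices 0..n_items-1 and cut them into train/test."""
    rng = np.random.default_rng(seed)   # seeded so the same split comes back every run
    idx = rng.permutation(n_items)
    split = int(train_frac * n_items)
    return idx[:split], idx[split:]

idx_train, idx_test = train_test_split_indices(10)
print(len(idx_train), len(idx_test))  # 8 2
```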
# Spektral is built on top of Keras, so we can build the model with Keras: a GCN layer embeds the nodes,
# a global sum pooling layer adds the node embeddings into one vector per graph,
# and a dense softmax layer classifies that vector.
# First, let's import the necessary layers:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Dropout
from spektral.layers import GCNConv, GlobalSumPool
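The embed, pool, classify pipeline described above can be sketched in plain NumPy for a single graph. This is a simplified forward pass with random weights, assuming one graph convolution of the form ReLU(Â X W); the ReLU is an illustrative choice and not necessarily what GCNConv applies by default, and dropout is omitted.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)   # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def gnn_forward(x, a_norm, w_conv, w_dense, b_dense):
    """Embed nodes with one graph convolution, sum-pool them, classify the graph."""
    h = relu(a_norm @ x @ w_conv)           # node embeddings, shape (n_nodes, n_hidden)
    g = h.sum(axis=0)                       # global sum pooling, shape (n_hidden,)
    return softmax(g @ w_dense + b_dense)   # class probabilities, shape (n_labels,)

rng = np.random.default_rng(0)
n_nodes, n_features, n_hidden, n_labels = 5, 3, 4, 2
x = rng.normal(size=(n_nodes, n_features))
a_norm = np.eye(n_nodes)                    # placeholder adjacency: self-loops only
probs = gnn_forward(x, a_norm,
                    rng.normal(size=(n_features, n_hidden)),
                    rng.normal(size=(n_hidden, n_labels)),
                    np.zeros(n_labels))
print(probs, probs.sum())
```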
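# Now, we can use model subclassing to define our model:
class ProteinsGNN(Model):
    def __init__(self, n_hidden, n_labels):
        super().__init__()
        # GCN layer with n_hidden output channels per node (not n_hidden layers)
        self.graph_conv = GCNConv(n_hidden)
        # Global pooling layer: sums the node embeddings into one vector per graph
        self.pool = GlobalSumPool()
        self.dropout = Dropout(0.5)  # assumed dropout rate
        self.dense = Dense(n_labels, 'softmax')  # one probability per class

    def call(self, inputs):
        # inputs is [x, a]: node features and the GCN-filtered adjacency matrix
        out = self.dropout(self.graph_conv(inputs))
        return self.dense(self.pool(out))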
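Global sum pooling is what makes the graph-level output independent of how the nodes happen to be numbered: summing over the node axis gives the same vector for any node ordering. The NumPy check below uses hypothetical random node embeddings to show this.

```python
import numpy as np

rng = np.random.default_rng(1)
h = rng.normal(size=(6, 4))              # hypothetical embeddings: 6 nodes, 4 channels
perm = rng.permutation(6)                # an arbitrary reordering of the nodes

pooled = h.sum(axis=0)                   # global sum pooling over the node axis
pooled_perm = h[perm].sum(axis=0)        # same pooling after shuffling the nodes

# Sums ignore order, so both pooled vectors agree (up to float rounding)
print(np.allclose(pooled, pooled_perm))  # True
```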
# Instantiate our model for training
model = ProteinsGNN(32, data.n_labels) |
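A quick way to sanity-check the model size is to count trainable weights by hand. The sketch assumes the graph convolution and the dense layer each have a weight matrix plus a bias, and that pooling and dropout have no weights; the feature count of 4 is hypothetical (the real value would come from `data.n_node_features`).

```python
def count_params(n_features, n_hidden, n_labels):
    """Trainable weights for one graph conv + sum pooling + dense softmax output."""
    gcn = n_features * n_hidden + n_hidden  # kernel + bias of the graph convolution
    pool = 0                                # global sum pooling has no weights
    dense = n_hidden * n_labels + n_labels  # kernel + bias of the output layer
    return gcn + pool + dense

# Hypothetical: 4 node features, 32 hidden channels, 2 classes
print(count_params(n_features=4, n_hidden=32, n_labels=2))  # 226
```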
# Compile the model with the Adam optimizer and categorical cross-entropy loss (which expects one-hot labels)
model.compile('adam', 'categorical_crossentropy')
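Categorical cross-entropy compares the softmax output with a one-hot target: for each graph it is -log of the probability assigned to the true class, averaged over the batch. The NumPy sketch below uses made-up predictions for two graphs; the clipping constant is an assumption to avoid log(0).

```python
import numpy as np

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean over samples of -sum_k y_k * log(p_k), for one-hot y_true."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)   # keep log() finite
    return float(np.mean(-np.sum(y_true * np.log(y_pred), axis=-1)))

y_true = np.array([[1.0, 0.0], [0.0, 1.0]])    # one-hot labels for 2 graphs
y_pred = np.array([[0.9, 0.1], [0.2, 0.8]])    # hypothetical softmax outputs
print(categorical_crossentropy(y_true, y_pred))  # about 0.164
```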
# Here's the catch: we can't pass the Spektral dataset directly to Keras' fit() method.
# Instead, we use a Loader, which iterates over the dataset's graphs and packs them into mini-batches.
# Since this is our first experiment with Spektral, we'll use the loader recommended in the getting-started tutorial.
# BatchLoader zero-pads the graphs in each batch to the same number of nodes and stacks them (batch mode).
# Training then feeds loader.load() and loader.steps_per_epoch to model.fit().
# TODO: read up on data modes and try other loaders later
from spektral.data import BatchLoader
loader = BatchLoader(data_train, batch_size=32)
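Batch mode needs every graph in a mini-batch to have the same number of nodes so the node features and adjacency matrices can be stacked into dense 3-D arrays. The NumPy sketch below shows the zero-padding idea with a hypothetical `pad_batch` helper; it is a simplified illustration, not Spektral's loader code.

```python
import numpy as np

def pad_batch(xs, adjs):
    """Hypothetical helper: zero-pad graphs to the largest node count and stack them."""
    n_max = max(x.shape[0] for x in xs)
    n_features = xs[0].shape[1]
    x_batch = np.zeros((len(xs), n_max, n_features))
    a_batch = np.zeros((len(xs), n_max, n_max))
    for i, (x, a) in enumerate(zip(xs, adjs)):
        n = x.shape[0]
        x_batch[i, :n] = x        # real nodes first; padding rows stay zero
        a_batch[i, :n, :n] = a    # padded nodes get no edges
    return x_batch, a_batch

# Two hypothetical graphs with 2 and 3 nodes, 2 features per node
xs = [np.ones((2, 2)), np.ones((3, 2))]
adjs = [np.array([[0, 1], [1, 0]]),
        np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]])]
x_batch, a_batch = pad_batch(xs, adjs)
print(x_batch.shape, a_batch.shape)  # (2, 3, 2) (2, 3, 3)
```

Because padded nodes have all-zero features and no edges, they add nothing to a global sum pool, though padding does cost memory when graph sizes vary a lot.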
# Install Spektral (notebook shell command; run this before the cells above)
!pip install spektral