Created
November 20, 2021 16:10
-
-
Save bigsnarfdude/a6253d359cbb9091b4792e6df2b6cdcb to your computer and use it in GitHub Desktop.
updated singleloader cora dataset GCN using spektral
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
This example implements the experiments on citation networks from the paper: | |
Semi-Supervised Classification with Graph Convolutional Networks (https://arxiv.org/abs/1609.02907) | |
Thomas N. Kipf, Max Welling | |
""" | |
import numpy as np | |
import tensorflow as tf | |
from tensorflow.keras.callbacks import EarlyStopping | |
from tensorflow.keras.losses import CategoricalCrossentropy | |
from tensorflow.keras.optimizers import Adam | |
from spektral.data.loaders import SingleLoader | |
from spektral.datasets.citation import Citation | |
from spektral.layers import GCNConv | |
from spektral.models.gcn import GCN | |
from spektral.transforms import LayerPreprocess | |
# Experiment configuration
learning_rate = 0.01  # step size handed to the Adam optimizer
seed = 0  # seed for TensorFlow's global random generator
epochs = 200  # upper bound on the number of training epochs
patience = 10  # patience given to the EarlyStopping callback
data = "cora"  # name of the citation dataset to load

# Seed TensorFlow so that weight initialization is reproducible
tf.random.set_seed(seed)
# Load data
# normalize_x=True asks the dataset class to normalize the node features;
# LayerPreprocess(GCNConv) is registered as a dataset transform.
dataset = Citation(
    data,
    normalize_x=True,
    transforms=[LayerPreprocess(GCNConv)],
)
# We convert the binary masks to sample weights so that we can compute the
# average loss over the nodes (following original implementation by
# Kipf & Welling)
def mask_to_weights(mask):
    """Convert a binary node mask into per-node sample weights.

    Each selected (non-zero) entry receives ``1 / k``, where ``k`` is the
    number of non-zero entries, so for a 0/1 mask the weights sum to one.

    Args:
        mask: Array-like of booleans or 0/1 values marking the selected nodes.

    Returns:
        A ``float32`` ``np.ndarray`` with the same shape as ``mask``. If no
        entry is selected, an all-zero array is returned instead of the NaNs
        that a division by zero would produce.
    """
    mask = np.asarray(mask)
    weights = mask.astype(np.float32)
    n_selected = np.count_nonzero(mask)
    if n_selected == 0:
        # Nothing selected: 0 / 0 would fill the result with NaN and poison
        # any loss computed from it, so return the all-zero weights as is.
        return weights
    return weights / n_selected
# One weight vector per split: training, validation and test.
weights_tr = mask_to_weights(dataset.mask_tr)
weights_va = mask_to_weights(dataset.mask_va)
weights_te = mask_to_weights(dataset.mask_te)
# Build the GCN with output and input sizes taken from the dataset.
model = GCN(
    n_labels=dataset.n_labels,
    n_input_channels=dataset.n_node_features,
)

# The loss uses "sum" reduction: together with the per-split sample weights
# (each selected node weighted by 1 / number of selected nodes) this yields
# the average loss over the selected nodes.
model.compile(
    loss=CategoricalCrossentropy(reduction="sum"),
    optimizer=Adam(learning_rate=learning_rate),
    weighted_metrics=["acc"],
)
# Train model
# Both loaders serve the same dataset; they differ only in the sample
# weights, which select the training or the validation nodes.
loader_tr = SingleLoader(dataset, sample_weights=weights_tr)
loader_va = SingleLoader(dataset, sample_weights=weights_va)

# Stop once `patience` epochs pass without improvement (monitored quantity
# left at the Keras default) and roll back to the best weights seen.
early_stopping = EarlyStopping(patience=patience, restore_best_weights=True)

model.fit(
    loader_tr.load(),
    epochs=epochs,
    steps_per_epoch=loader_tr.steps_per_epoch,
    validation_data=loader_va.load(),
    validation_steps=loader_va.steps_per_epoch,
    callbacks=[early_stopping],
)
# Evaluate model
print("Evaluating model.")

# The test loader reuses the dataset with the test-split sample weights.
loader_te = SingleLoader(dataset, sample_weights=weights_te)
eval_results = model.evaluate(
    loader_te.load(),
    steps=loader_te.steps_per_epoch,
)

# eval_results is unpacked positionally into the loss and accuracy slots.
print("Done.\nTest loss: {}\nTest accuracy: {}".format(*eval_results))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment