@bigsnarfdude · Created November 20, 2021
Updated SingleLoader Cora dataset GCN using Spektral
"""
This example implements the experiments on citation networks from the paper:
Semi-Supervised Classification with Graph Convolutional Networks (https://arxiv.org/abs/1609.02907)
Thomas N. Kipf, Max Welling
"""
import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.optimizers import Adam
from spektral.data.loaders import SingleLoader
from spektral.datasets.citation import Citation
from spektral.layers import GCNConv
from spektral.models.gcn import GCN
from spektral.transforms import LayerPreprocess
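
# Note: this script assumes Spektral 1.x (where spektral.models.gcn.GCN and
# spektral.data.loaders.SingleLoader are available) on top of TensorFlow 2.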

learning_rate = 1e-2  # Adam step size
seed = 0  # random seed
epochs = 200  # maximum number of training epochs
patience = 10  # early-stopping patience (epochs without validation improvement)
data = "cora"  # citation dataset: "cora", "citeseer", or "pubmed"

tf.random.set_seed(seed=seed)  # make weight initialization reproducible

# Load data
dataset = Citation(data, normalize_x=True, transforms=[LayerPreprocess(GCNConv)])
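# LayerPreprocess(GCNConv) applies GCNConv's preprocessing step to the
# adjacency matrix once at load time (self-loops plus symmetric normalization),
# so the normalized matrix does not need to be recomputed during training.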

# We convert the binary masks to sample weights so that we can compute the
# average loss over the nodes (following the original implementation by
# Kipf & Welling)
def mask_to_weights(mask):
    return mask.astype(np.float32) / np.count_nonzero(mask)


weights_tr, weights_va, weights_te = (
    mask_to_weights(mask)
    for mask in (dataset.mask_tr, dataset.mask_va, dataset.mask_te)
)

model = GCN(n_labels=dataset.n_labels, n_input_channels=dataset.n_node_features)
model.compile(
    optimizer=Adam(learning_rate),
    loss=CategoricalCrossentropy(reduction="sum"),
    weighted_metrics=["acc"],
)
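# Each weight vector sums to 1, so the "sum" reduction combined with these
# sample weights yields the mean cross-entropy over the masked nodes only;
# nodes outside the mask have weight 0 and do not contribute to the loss.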

# Train model
loader_tr = SingleLoader(dataset, sample_weights=weights_tr)
loader_va = SingleLoader(dataset, sample_weights=weights_va)
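# SingleLoader yields the entire graph as a single batch, so steps_per_epoch
# is 1 and every epoch performs one full-batch gradient step over all nodes.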
model.fit(
    loader_tr.load(),
    steps_per_epoch=loader_tr.steps_per_epoch,
    validation_data=loader_va.load(),
    validation_steps=loader_va.steps_per_epoch,
    epochs=epochs,
    callbacks=[EarlyStopping(patience=patience, restore_best_weights=True)],
)

# Evaluate model
print("Evaluating model.")
loader_te = SingleLoader(dataset, sample_weights=weights_te)
eval_results = model.evaluate(loader_te.load(), steps=loader_te.steps_per_epoch)
print("Done.\nTest loss: {}\nTest accuracy: {}".format(*eval_results))