Created
February 19, 2021 11:25
-
-
Save joaquincabezas/aae4cd1da5653f692653165c415b775f to your computer and use it in GitHub Desktop.
Modified example of GNNExplainer usage with an adjusted coefficient for edge_size. It helps with stability and conciseness for explanations of small (10-100 edges) computation graphs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Original example at
# https://github.com/rusty1s/pytorch_geometric/blob/master/examples/gnn_explainer.py
# In this modified version we take into account datasets where many nodes have a
# 2-hop neighbourhood comprising 10 to 100 edges. In these cases, the current fixed
# value of the edge_size coefficient is not working as intended. We propose a
# variable coefficient, dependent on the number of edges of the 2-hop neighbourhood
# (we will consider this neighbourhood as the computation graph).
# More info at: https://github.com/rusty1s/pytorch_geometric/issues/1985
# or at @joaquincabezas
import os.path as osp | |
import torch | |
import torch.nn.functional as F | |
import matplotlib.pyplot as plt | |
from torch_geometric.datasets import Planetoid | |
import torch_geometric.transforms as T | |
from torch_geometric.nn import GCNConv, GNNExplainer | |
# This import is needed for computing N (edges in the computation graph) | |
from torch_geometric.utils import k_hop_subgraph | |
# Load the 'Cora' Planetoid dataset with row-normalised features. Files are
# kept under ../data/Planetoid relative to the directory of this script.
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid')
dataset = Planetoid(path, 'Cora', transform=T.NormalizeFeatures())

# Work with the first graph stored in the dataset.
data = dataset[0]
class Net(torch.nn.Module):
    """Two stacked GCNConv layers followed by a log-softmax over dim 1.

    Input and output widths are read from the module-level ``dataset``
    (``num_features`` and ``num_classes``); the hidden width is 16.
    """

    def __init__(self):
        super().__init__()
        hidden_channels = 16
        self.conv1 = GCNConv(dataset.num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, dataset.num_classes)

    def forward(self, x, edge_index):
        """Run both convolutions on ``x`` / ``edge_index`` and return log-softmax scores."""
        hidden = self.conv1(x, edge_index)
        hidden = F.relu(hidden)
        # Dropout follows the module's train/eval flag.
        hidden = F.dropout(hidden, training=self.training)
        scores = self.conv2(hidden, edge_index)
        return F.log_softmax(scores, dim=1)
# For reproducible experiments, fix the random seeds at this point, e.g.:
# seed = 1234
# torch.manual_seed(seed)
# random.seed(seed)   (this one also needs `import random`)

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)

model = Net().to(device)
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
x, edge_index = data.x, data.edge_index

# Full-batch training for 200 epochs; the loss only looks at the entries
# selected by data.train_mask.
for epoch in range(1, 201):
    model.train()
    optimizer.zero_grad()
    log_probs = model(x, edge_index)
    train_mask = data.train_mask
    loss = F.nll_loss(log_probs[train_mask], data.y[train_mask])
    loss.backward()
    optimizer.step()
# Create the explainer first, so that its edge_size coefficient can be
# adjusted before any explanation is computed.
explainer = GNNExplainer(model, epochs=200, log=False)

# Node to be explained
node_idx = 10

# Computation graph of node_idx: its 2-hop neighbourhood. Only the last
# element returned by k_hop_subgraph (the edge mask) is needed here.
hard_edge_mask = k_hop_subgraph(node_idx, 2, edge_index)[3]

# N is the number of edges within the computation graph
N = hard_edge_mask.sum().item()

S = 5  # Objective size of the explanation
L = 0.5  # Objective contribution to the loss function (for edge_size)

# Adjust the coefficient for edge_size.
# Rationale (https://github.com/rusty1s/pytorch_geometric/issues/1985):
# at the end of the process a relevant edge has a mask value of about 0.9
# and an irrelevant edge about 0.1. With S relevant edges out of N, the mask
# values then add up to roughly 0.9*S + 0.1*(N - S) = (N + 8*S) / 10.
# Assuming the edge_size term is the coefficient times that sum, asking for
# a contribution of L gives coefficient = 10*L / (N + 8*S).
# NOTE(review): if `coeffs` is a dict shared at class level in the installed
# torch_geometric version, this assignment affects every GNNExplainer
# instance - confirm before creating several explainers.
explainer.coeffs['edge_size'] = (10 * L) / (N + 8 * S)

# Explain the node and draw the resulting subgraph.
node_feat_mask, edge_mask = explainer.explain_node(node_idx, x, edge_index)
ax, G = explainer.visualize_subgraph(node_idx, edge_index, edge_mask, y=data.y)
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment