Regression with GAT: can't learn features of its own node
#%%
"""GCN using DGL nn package
References:
- Semi-Supervised Classification with Graph Convolutional Networks
- Paper: https://arxiv.org/abs/1609.02907
- Code: https://github.com/tkipf/gcn
"""
from random import randint, random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl
from dgl.nn.pytorch import GraphConv, GATConv
from ray.rllib.models.torch.misc import SlimFC
class GCN(nn.Module):
    def __init__(self, in_feats, n_hidden, n_classes, n_layers, activation, dropout):
        super(GCN, self).__init__()
        self.layers = nn.ModuleList()
        # input layer
        self.layers.append(GraphConv(in_feats, n_hidden, activation=activation))
        # hidden layers
        for i in range(n_layers - 1):
            self.layers.append(GraphConv(n_hidden, n_hidden, activation=activation))
        # output layer: a single linear unit for regression, instead of a
        # GraphConv classification head
        # self.layers.append(GraphConv(n_hidden, n_classes))
        self.linear = SlimFC(in_size=n_hidden, out_size=1)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, g):
        h = g.ndata["feat"]
        for i, layer in enumerate(self.layers):
            # if i != 0:
            #     h = self.dropout(h)
            h = layer(g, h)
        score = self.linear(h)
        return score
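#%%
# Shape sanity check for GCN (an illustrative sketch, not part of the original
# gist; the toy graph, sizes, and underscore names are assumptions). With one
# GraphConv layer and the SlimFC head, (N, in_feats) features map to (N, 1).
_g = dgl.add_self_loop(dgl.rand_graph(5, 10))
_g.ndata["feat"] = torch.randn(5, 2)
_gcn = GCN(2, 8, 1, 1, F.relu, 0.0)
assert _gcn(_g).shape == (5, 1)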
class GAT(nn.Module):
    def __init__(
        self,
        in_feats,
        n_hidden,
        n_classes,
        n_layers,
        activation,
        heads=None,
        feat_drop=0,
        attn_drop=0,
        negative_slope=0.2,
        residual=False,
    ):
        super(GAT, self).__init__()
        if not heads:
            heads = ([1] * n_layers) + [1]
        self.n_layers = n_layers
        self.gat_layers = nn.ModuleList()
        self.activation = activation
        # input projection (no residual)
        self.gat_layers.append(
            GATConv(
                in_feats,
                n_hidden,
                heads[0],
                feat_drop,
                attn_drop,
                negative_slope,
                False,
                self.activation,
            )
        )
        # hidden layers
        for l in range(1, n_layers):
            # due to multi-head concatenation, in_feats = n_hidden * n_heads
            self.gat_layers.append(
                GATConv(
                    n_hidden * heads[l - 1],
                    n_hidden,
                    heads[l],
                    feat_drop,
                    attn_drop,
                    negative_slope,
                    residual,
                    self.activation,
                )
            )
        # output projection
        self.gat_layers.append(
            GATConv(
                n_hidden * heads[-2],
                n_classes,
                heads[-1],
                feat_drop,
                attn_drop,
                negative_slope,
                residual,
                None,
            )
        )

    def forward(self, g):
        h = g.ndata["feat"]
        for l in range(self.n_layers):
            # concatenate heads: (N, num_heads, n_hidden) -> (N, num_heads * n_hidden)
            h = self.gat_layers[l](g, h).flatten(1)
        # output projection: average over heads
        logits = self.gat_layers[-1](g, h).mean(1)
        return logits
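#%%
# Shape sanity check for GAT (an illustrative sketch, not part of the original
# gist; the toy graph and sizes are assumptions). With n_layers=1 and
# heads=[3, 1], the hidden layer concatenates 3 heads of n_hidden features,
# and the output projection averages its single head down to (N, n_classes).
_g = dgl.add_self_loop(dgl.rand_graph(5, 10))
_g.ndata["feat"] = torch.randn(5, 2)
_gat = GAT(2, 8, 1, 1, F.relu, heads=[3, 1])
assert _gat(_g).shape == (5, 1)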
def train_test_val_mask(N, train_size, test_size, val_size):
    """returns train/test/val masks as boolean torch tensors"""
    train_mask = torch.zeros(N)
    train_mask[:train_size] = True
    test_mask = torch.zeros(N)
    test_mask[train_size : train_size + test_size] = True
    val_mask = torch.zeros(N)
    val_mask[train_size + test_size : train_size + test_size + val_size] = True
    return train_mask.bool(), test_mask.bool(), val_mask.bool()
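#%%
# Illustrative check (not in the original gist): an 8/1/1 split over 10 nodes
# yields three disjoint boolean masks over contiguous index ranges.
_tr, _te, _va = train_test_val_mask(10, 8, 1, 1)
assert int(_tr.sum()) == 8 and int(_te.sum()) == 1 and int(_va.sum()) == 1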
def gen_random_graph(n_nodes, n_edges):
    """
    generate a random dgl graph with n_nodes nodes and n_edges edges, with the
    following ndata properties:
        strat: either 0 or 1
        degree: the node's in-degree, normalized by n_nodes
        label: degree * strat
    """
    g = dgl.rand_graph(n_nodes, n_edges)
    g = dgl.remove_self_loop(g)
    g = dgl.add_self_loop(g)
    g.ndata["strat"] = torch.tensor([randint(0, 1) for _ in range(n_nodes)]).float()
    # g.ndata["rand0"] = torch.tensor([random() for _ in range(n_nodes)]).float()
    # g.ndata["rand1"] = torch.tensor([random() for _ in range(n_nodes)]).float()
    g.ndata["degree"] = (
        torch.tensor([g.in_degrees(i) for i in range(n_nodes)]).float() / n_nodes
    )
    g.ndata["feat"] = torch.stack(
        [g.ndata[name] for name in ["degree", "strat"]],
        axis=1,
        # [g.ndata[name] for name in ["degree", "strat", "rand0", "rand1"]], axis=1
    ).float()
    g.ndata["label"] = g.ndata["degree"] * g.ndata["strat"]
    # g.ndata["label"] = (g.ndata["degree"] * g.ndata["strat"] > 7).long()  # classif
    N = n_nodes
    # NB: unpack in the same order the helper returns: train, test, val
    (
        g.ndata["train_mask"],
        g.ndata["test_mask"],
        g.ndata["val_mask"],
    ) = train_test_val_mask(N, int(N * 0.8), int(N * 0.1), int(N * 0.1))
    return g
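#%%
# Illustrative check (not in the original gist): "feat" stacks the normalized
# degree and the binary strat flag into a 2-dim vector per node, and the
# regression target is exactly their product.
_g = gen_random_graph(20, 60)
assert _g.ndata["feat"].shape == (20, 2)
assert torch.equal(_g.ndata["label"], _g.ndata["degree"] * _g.ndata["strat"])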
# from gcn_mp import GCN
# from gcn_spmv import GCN
def calc_loss(logits, labels):
    loss_fcn = torch.nn.MSELoss()
    loss = loss_fcn(logits, labels)
    return loss.item()
def evaluate(model, gs):
    """mean validation MSE over a list of graphs"""
    loss_fcn = torch.nn.MSELoss()
    losses = []
    model.eval()
    with torch.no_grad():
        for g in gs:
            val_mask = g.ndata["val_mask"]
            labels = g.ndata["label"]
            logits = model(g)
            loss = loss_fcn(logits[val_mask].squeeze(), labels[val_mask])
            losses.append(loss.item())
    return np.mean(losses)
from types import SimpleNamespace

args = SimpleNamespace(
    # dropout=0.5,
    dropout=0,
    gpu=0,
    lr=0.01,
    n_epochs=100,
    n_hidden=32,
    n_layers=1,
    weight_decay=5e-4,
    self_loop=True,
    many_graphs=True,
)
in_feats = 2
n_classes = 1
# create GAT model (GCN variant commented out)
# model = GCN(in_feats, args.n_hidden, n_classes, args.n_layers, F.relu, args.dropout)
heads = ([3] * args.n_layers) + [1]
model = GAT(in_feats, args.n_hidden, n_classes, args.n_layers, F.relu, heads=heads)
loss_fcn = torch.nn.MSELoss()
# loss_fcn = torch.nn.CrossEntropyLoss()
# use optimizer
optimizer = torch.optim.Adam(
    model.parameters(), lr=args.lr, weight_decay=args.weight_decay
)
#%%
# many graphs
num_graphs = 10
gs = [gen_random_graph(50, 300) for _ in range(num_graphs)]
losses = []
#%%
# one graph at a time, ignore for now
# for g in gs:
#     features = g.ndata["feat"]
#     labels = g.ndata["label"]
#     train_mask = g.ndata["train_mask"]
#     val_mask = g.ndata["val_mask"]
#     test_mask = g.ndata["test_mask"]
#     in_feats = features.shape[1]
#     logits = model(g)
#     loss = loss_fcn(logits[train_mask], labels[train_mask])
#     losses.append(loss.item())
#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()
#     print(evaluate(model, gs))
# print(evaluate(model, [g]))
# batched graph | |
num_graphs = 1000 | |
gs = [gen_random_graph(10, 70) for _ in range(num_graphs)] | |
g = dgl.batch(gs) | |
features = g.ndata["feat"] | |
labels = g.ndata["label"] | |
train_mask = g.ndata["train_mask"] | |
val_mask = g.ndata["val_mask"] | |
test_mask = g.ndata["test_mask"] | |
in_feats = features.shape[1] | |
# n_classes = data.num_labels | |
n_classes = 1 | |
#%%
# train on the batched graph
for _ in range(100):
    logits = model(g)
    loss = loss_fcn(logits[train_mask].squeeze(), labels[train_mask])
    losses.append(loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
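#%%
# Optional: plot the training curve collected in `losses` (a sketch assuming
# matplotlib is installed; not part of the original gist).
import matplotlib.pyplot as plt

plt.plot(losses)
plt.xlabel("step")
plt.ylabel("train MSE")
plt.show()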
#%%
logits = model(g)
print(f"logits {logits.squeeze()[:10]}")
print(f"label {labels[:10]}")
print(f'label mean: {g.ndata["label"].mean()}')
print(f"batched loss {calc_loss(logits.squeeze(), labels)}")
print(f"many gs loss {evaluate(model, gs)}")
# %%
# viz: draw one graph, coloring nodes by label and weighting edges by the
# attention scores of the output GAT layer
import networkx as nx

G = gs[1]
nx_G = G.to_networkx().to_undirected()
pos = nx.kamada_kawai_layout(nx_G)
h = model.gat_layers[0](G, G.ndata["feat"]).flatten(1)
r, e = model.gat_layers[1](G, h, get_attention=True)
edge_weights = e[:, 0, :].squeeze().detach().numpy()
nx.draw(
    nx_G,
    pos,
    with_labels=True,
    cmap="hot",
    node_color=G.ndata["label"].numpy(),
    # min-max normalize attention scores into [0, 1] line widths
    width=(edge_weights - edge_weights.min())
    / (edge_weights.max() - edge_weights.min()),
)
print(edge_weights)
# %%