Skip to content

Instantly share code, notes, and snippets.

@TheExGenesis
Last active September 30, 2021 17:31
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save TheExGenesis/590e6363b7e13400ff1e3efdb80f8ef3 to your computer and use it in GitHub Desktop.
Save TheExGenesis/590e6363b7e13400ff1e3efdb80f8ef3 to your computer and use it in GitHub Desktop.
Training a GNN to predict degree*strat
# https://github.com/dmlc/dgl/blob/master/examples/pytorch/gcn/train.py
#%%
from random import randint
import torch
import torch.nn as nn
from dgl.nn.pytorch import GraphConv
import dgl
class GCN(nn.Module):
    """Stack of DGL ``GraphConv`` layers producing one output vector per node.

    Layout: an input layer (``in_feats -> n_hidden``), ``n_layers - 1`` hidden
    layers (``n_hidden -> n_hidden``), both using ``activation``, followed by an
    output layer (``n_hidden -> n_classes``) with no activation. Dropout is
    applied to the input of every layer except the first.

    Args:
        in_feats: size of each node's input feature vector.
        n_hidden: width of the hidden layers.
        n_classes: size of each node's output vector.
        n_layers: number of activated layers (input layer + hidden layers);
            the module holds ``n_layers + 1`` GraphConv layers in total.
        activation: callable applied after each non-output layer (e.g. ``F.relu``).
        dropout: dropout probability used between layers.
    """

    def __init__(self, in_feats, n_hidden, n_classes, n_layers, activation, dropout):
        super().__init__()
        self.layers = nn.ModuleList()
        # input layer
        self.layers.append(GraphConv(in_feats, n_hidden, activation=activation))
        # hidden layers
        for _ in range(n_layers - 1):
            self.layers.append(GraphConv(n_hidden, n_hidden, activation=activation))
        # output layer: no activation, so outputs are raw scores / regression values
        self.layers.append(GraphConv(n_hidden, n_classes))
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, g):
        """Run the layer stack on ``g`` and return the per-node outputs.

        Input features are read from ``g.ndata["features"]``. The self-loop
        handling below works on new graph objects, so the caller's ``g`` is
        left unchanged.
        """
        # dgl.remove_self_loop / dgl.add_self_loop return *new* graphs, so the
        # result must be assigned. Normalise to exactly one self-loop per node
        # (as DGL's GCN example does): each node then also aggregates its own
        # features, and no node has zero in-degree, which GraphConv rejects by
        # default.
        g = dgl.add_self_loop(dgl.remove_self_loop(g))
        h = g.ndata["features"]
        for i, layer in enumerate(self.layers):
            if i != 0:
                h = self.dropout(h)
            h = layer(g, h)
        return h
def gen_random_graph(n_nodes, n_edges):
    """
    Generate a random DGL graph with ``n_nodes`` nodes and ``n_edges`` edges.

    The following float node data (``g.ndata``) is attached:
        strat: 0.0 or 1.0, drawn independently per node, shape ``(n_nodes,)``
        degree: the node's in-degree, shape ``(n_nodes,)``
        features: ``degree`` and ``strat`` stacked column-wise, shape ``(n_nodes, 2)``
        labels: ``degree * strat`` (the regression target), shape ``(n_nodes,)``
    """
    g = dgl.rand_graph(n_nodes, n_edges)
    # random.randint is inclusive on both ends, so each node gets 0 or 1.
    g.ndata["strat"] = torch.tensor([randint(0, 1) for _ in range(n_nodes)]).float()
    # One vectorised query for all nodes instead of one graph call per node.
    g.ndata["degree"] = g.in_degrees().float()
    g.ndata["features"] = torch.stack((g.ndata["degree"], g.ndata["strat"]), dim=1)
    g.ndata["labels"] = g.ndata["degree"] * g.ndata["strat"]
    return g
#%%
import argparse
import time
import numpy as np
import torch
import torch.nn.functional as F
import dgl
from dgl.data import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset
# from gcn_mp import GCN
# from gcn_spmv import GCN
def train_loop(model, train_data, loss_fcn, optimizer, num_epochs=1, warmup_steps=3):
    """Train ``model`` for ``num_epochs`` passes over ``train_data``, one graph per step.

    Each step runs a forward pass on a single graph and compares the output
    with that graph's own ``g.ndata["labels"]``. It then takes one optimizer
    step and prints a progress line.

    Args:
        model: module called as ``model(g)`` that returns per-node outputs.
        train_data: re-iterable collection of graphs exposing
            ``ndata["labels"]`` and ``number_of_edges()``; it is iterated once
            per epoch.
        loss_fcn: callable ``loss_fcn(prediction, target)`` returning a scalar tensor.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of passes over ``train_data``.
        warmup_steps: steps at the start of each epoch excluded from the timing
            statistics (presumably to skip warm-up overhead).

    Returns:
        List of per-step loss values (floats) in training order.
    """
    model.train()  # sets training mode (e.g. dropout active) rather than evaluation
    durations = []
    losses = []
    for epoch in range(num_epochs):
        for step, g in enumerate(train_data):
            timed = step >= warmup_steps
            if timed:
                t0 = time.perf_counter()
            # forward
            logits = model(g)
            # Use *this* graph's targets, reshaped to match the model output
            # when they only differ in shape (e.g. (N,) vs (N, 1)). Without the
            # reshape, MSELoss would broadcast them into an (N, N) comparison.
            labels = g.ndata["labels"]
            if labels.shape != logits.shape and labels.numel() == logits.numel():
                labels = labels.view_as(logits)
            loss = loss_fcn(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if timed:
                durations.append(time.perf_counter() - t0)
            losses.append(loss.item())
            if durations:
                mean_dur = np.mean(durations)
                time_str = f"{mean_dur}"
                tput_str = f"{g.number_of_edges() / mean_dur / 1000}"
            else:
                # No timed steps yet: avoid np.mean([]) (nan + RuntimeWarning).
                time_str = tput_str = "n/a"
            print(
                f"Epoch: {epoch} Step {step} | Time(s) {time_str} | Loss {losses[-1]} | ETputs(KTEPS) {tput_str}"
            )
    return losses
#%%
# Setup cell: build the synthetic dataset and the model / loss / optimizer
# used by the cells below. Statement order matters: the three dataset lists
# consume the random stream in sequence.
# NOTE(review): re-imports `nn` (already imported as `torch.nn as nn` at the
# top); `tensor` is not referenced anywhere in this file.
from torch import nn, tensor
# load and preprocess dataset
# data = CoraGraphDataset()
# g = data[0]
dataset_size = 1000
# Every generated graph has 10 nodes and 70 random edges.
n_nodes, n_edges = 10, 70
# 70% / 20% / 10% of dataset_size -> 700 training, 200 test, 100 validation graphs.
training_data = [
gen_random_graph(n_nodes, n_edges) for _ in range(int(dataset_size * 0.7))
]
# NOTE(review): test_data and val_data are not used by any cell in this file.
test_data = [gen_random_graph(n_nodes, n_edges) for _ in range(int(dataset_size * 0.2))]
val_data = [gen_random_graph(n_nodes, n_edges) for _ in range(int(dataset_size * 0.1))]
# The first training graph serves as a template for input/output sizes.
g = training_data[0]
features = g.ndata["features"]
# Module-level `labels` holds the targets of the first training graph only.
labels = g.ndata["labels"]
# Number of input features per node (degree and strat, per gen_random_graph).
in_feats = features.shape[1]
# Single regression output per node (paired with MSELoss below).
n_classes = 1
# Edge count of the first training graph (presumably the 70 requested above).
n_edges = g.number_of_edges()
# normalization
# NOTE(review): computes deg^-0.5 (zero-degree nodes set to 0) and stores it
# only on the first training graph. GCN.forward reads only ndata["features"],
# and GraphConv normalises by default, so this looks like leftover code from
# the DGL example — confirm before relying on it.
degs = g.in_degrees().float()
norm = torch.pow(degs, -0.5)
norm[torch.isinf(norm)] = 0
g.ndata["norm"] = norm.unsqueeze(1)
# Model hyperparameters: 2 activated GraphConv layers of width 32 plus output layer.
n_hidden = 32
n_layers = 2
dropout = 0.5
# dropout = 0
model = GCN(in_feats, n_hidden, n_classes, n_layers, F.relu, dropout)
# Regression target (degree * strat), so mean-squared error.
loss_fcn = torch.nn.MSELoss()
# loss_fcn = torch.nn.CrossEntropyLoss()
# loss_fcn = torch.nn.L1Loss()
lr = 0.01
weight_decay = 5e-4
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
#%%
# Training cell: two full passes over the training graphs.
train_loop(
    model=model,
    train_data=training_data,
    loss_fcn=loss_fcn,
    optimizer=optimizer,
    num_epochs=2,
)
#%%
# Inspection cell: take one more optimisation step on a randomly chosen
# training graph and display its predictions next to its targets.
model.train()  # sets training mode (dropout active) rather than evaluation
# random.randint is inclusive on both ends, so the upper bound is len - 1;
# using len(...) could index one past the end and raise IndexError.
g = training_data[randint(0, len(training_data) - 1)]
features = g.ndata["features"]
logits = model(g)
# Match the (n_nodes, 1) model output so MSELoss compares node-by-node
# instead of broadcasting (N, 1) against (N,) into an (N, N) comparison.
labels = g.ndata["labels"].view_as(logits)
loss = loss_fcn(logits, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
logits, labels
# %%
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment