Skip to content

Instantly share code, notes, and snippets.

@Rhett-Ying
Created September 26, 2021 02:02
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Rhett-Ying/39d350047e936e6a1e5ac760d6c8306a to your computer and use it in GitHub Desktop.
Save Rhett-Ying/39d350047e936e6a1e5ac760d6c8306a to your computer and use it in GitHub Desktop.
dgl_discuss_2351
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.data
# Load the PROTEINS graph-classification benchmark through DGL's GINDataset.
# self_loop=True is passed through to GINDataset (presumably so every node has
# at least one incoming edge for GraphConv/GATConv — confirm).
dataset = dgl.data.GINDataset('PROTEINS', self_loop=True)
print('Node feature dimensionality:', dataset.dim_nfeats)
print('Number of graph categories:', dataset.gclasses)

from dgl.dataloading import GraphDataLoader
from torch.utils.data.sampler import SubsetRandomSampler

# Fraction of graphs assigned to the training split; the rest form the test split.
TRAIN_FRACTION = 0.8
# Number of graphs batched together in each minibatch.
BATCH_SIZE = 5

num_examples = len(dataset)
num_train = int(num_examples * TRAIN_FRACTION)
# The split is contiguous by dataset index: the first num_train graphs train,
# the remainder test. SubsetRandomSampler only shuffles *within* each split.
# NOTE(review): if the dataset is ordered by label, this split is skewed —
# consider permuting indices before splitting.
train_sampler = SubsetRandomSampler(torch.arange(num_train))
test_sampler = SubsetRandomSampler(torch.arange(num_train, num_examples))
train_dataloader = GraphDataLoader(
    dataset, sampler=train_sampler, batch_size=BATCH_SIZE, drop_last=False)
test_dataloader = GraphDataLoader(
    dataset, sampler=test_sampler, batch_size=BATCH_SIZE, drop_last=False)

# Pull one minibatch to inspect it; each item is (batched_graph, labels).
it = iter(train_dataloader)
batch = next(it)
print(batch)
batched_graph, labels = batch
print('Number of nodes for each graph element in the batch:', batched_graph.batch_num_nodes())
print('Number of edges for each graph element in the batch:', batched_graph.batch_num_edges())
# Recover the original graph elements from the minibatch
graphs = dgl.unbatch(batched_graph)
print('The original graphs in the minibatch:')
print(graphs)
from dgl.nn import GraphConv, GATConv
class GCN(nn.Module):
    """Two-layer GraphConv model with a mean-pooling readout for graph classification.

    Args:
        in_feats: Size of each input node feature vector.
        h_feats: Output size of the first GraphConv layer.
        num_classes: Output size of the second GraphConv layer (logits per class).
    """

    def __init__(self, in_feats, h_feats, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        """Compute per-graph logits for the (possibly batched) graph ``g``.

        Args:
            g: A DGL graph, e.g. a batch yielded by GraphDataLoader.
            in_feat: Node features for ``g``, one row per node.

        Returns:
            The node outputs of ``conv2`` averaged per graph by ``dgl.mean_nodes``.
        """
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        # Do the readout inside a local scope so the temporary 'h' node feature
        # (and its reference into the autograd graph) is not left on the
        # caller's graph after forward returns.
        with g.local_scope():
            g.ndata['h'] = h
            x = dgl.mean_nodes(g, 'h')
            # Debug output used to check readout shapes against the batch.
            print("--------- x.shape: {}, g: {}".format(x.shape, g))
        return x
class GAT(nn.Module):
    """Two multi-head GATConv layers, a linear classifier, and a mean-pooling readout.

    Args:
        in_feats: Size of each input node feature vector.
        h_feats: Per-head output size of both GATConv layers.
        num_classes: Output size of the final linear layer (logits per class).
        num_heads: Number of attention heads in each GATConv layer.
    """

    def __init__(self, in_feats, h_feats, num_classes, num_heads=3):
        super(GAT, self).__init__()
        self.conv1 = GATConv(in_feats, h_feats, num_heads=num_heads)
        self.conv2 = GATConv(h_feats, h_feats, num_heads=num_heads)
        # Despite the name, conv3 is a plain per-node linear classifier; the
        # attribute name is kept so existing state_dict keys still match.
        self.conv3 = torch.nn.Linear(in_features=h_feats, out_features=num_classes)

    def forward(self, g, in_feat):
        """Compute per-graph logits for the (possibly batched) graph ``g``.

        Args:
            g: A DGL graph, e.g. a batch yielded by GraphDataLoader.
            in_feat: Node features for ``g``, one row per node.

        Returns:
            The per-node outputs of ``conv3`` averaged per graph by ``dgl.mean_nodes``.
        """
        h = self.conv1(g, in_feat)
        # GATConv produces one output per head; averaging over dim 1 (the head
        # axis) leaves h_feats columns, matching conv2's in_feats.
        h = torch.mean(h, dim=1)
        h = self.conv2(g, h)
        # Collapse the heads again so conv3 receives h_feats columns.
        h = torch.mean(h, dim=1)
        # NOTE(review): unlike GCN, no non-linearity is applied between layers
        # here — confirm that is intended.
        h = self.conv3(h)
        # Local scope keeps the temporary 'h' feature off the caller's graph.
        with g.local_scope():
            g.ndata['h'] = h
            x = dgl.mean_nodes(g, 'h')
            # Debug output used to check readout shapes against the batch size.
            print("--------- h.shape: {}, x.shape: {}, g.batch_size: {}".format(h.shape, x.shape, g.batch_size))
        return x
# Training hyperparameters for this quick run.
HIDDEN_SIZE = 16
LEARNING_RATE = 0.01
NUM_EPOCHS = 1

# Build the GAT model; to compare against the GraphConv baseline, construct
# GCN(dataset.dim_nfeats, HIDDEN_SIZE, dataset.gclasses) instead.
model = GAT(dataset.dim_nfeats, HIDDEN_SIZE, dataset.gclasses)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

for epoch in range(NUM_EPOCHS):
    for batched_graph, labels in train_dataloader:
        # Node features are read from the 'attr' field and cast to float.
        node_feats = batched_graph.ndata['attr'].float()
        pred = model(batched_graph, node_feats)
        loss = F.cross_entropy(pred, labels)

        # Clear stale gradients, backpropagate, then update the parameters.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment