Created September 26, 2021 02:02
Save Rhett-Ying/39d350047e936e6a1e5ac760d6c8306a to your computer and use it in GitHub Desktop.
dgl_discuss_2351
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters.
import dgl
import dgl.data
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.dataloading import GraphDataLoader
from dgl.nn import GraphConv, GATConv
from torch.utils.data.sampler import SubsetRandomSampler

# Load the PROTEINS graph-classification dataset through DGL's GINDataset,
# requesting self-loops on every graph.
dataset = dgl.data.GINDataset('PROTEINS', self_loop=True)
print('Node feature dimensionality:', dataset.dim_nfeats)
print('Number of graph categories:', dataset.gclasses)

# 80/20 train/test split. The indices are permuted first (with a fixed seed so
# runs are reproducible) so that split membership does not depend on the order
# in which the dataset stores its graphs; a plain `arange` split would always
# put the *last* 20% of the stored graphs into the test set.
num_examples = len(dataset)
num_train = int(num_examples * 0.8)
split_generator = torch.Generator().manual_seed(0)
shuffled_indices = torch.randperm(num_examples, generator=split_generator)
train_sampler = SubsetRandomSampler(shuffled_indices[:num_train])
test_sampler = SubsetRandomSampler(shuffled_indices[num_train:])

train_dataloader = GraphDataLoader(
    dataset, sampler=train_sampler, batch_size=5, drop_last=False)
test_dataloader = GraphDataLoader(
    dataset, sampler=test_sampler, batch_size=5, drop_last=False)

# Inspect one training minibatch: each item yielded by the loader unpacks into
# a batched graph and the labels of the graphs it contains.
batch = next(iter(train_dataloader))
print(batch)
batched_graph, labels = batch
print('Number of nodes for each graph element in the batch:', batched_graph.batch_num_nodes())
print('Number of edges for each graph element in the batch:', batched_graph.batch_num_edges())

# Recover the original graph elements from the minibatch.
graphs = dgl.unbatch(batched_graph)
print('The original graphs in the minibatch:')
print(graphs)
class GCN(nn.Module):
    """Two-layer graph convolutional network for graph classification.

    Node features go through GraphConv -> ReLU -> GraphConv. The per-node
    outputs are then mean-pooled into one vector of ``num_classes`` scores
    per graph.

    Args:
        in_feats: Size of each input node feature vector.
        h_feats: Size of the hidden layer.
        num_classes: Size of the per-graph output vector.
        verbose: If True, print the pooled output shape and the graph on
            every forward pass (debugging aid; off by default).
    """

    def __init__(self, in_feats, h_feats, num_classes, verbose=False):
        super().__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_classes)
        self.verbose = verbose

    def forward(self, g, in_feat):
        """Return one row of ``num_classes`` scores per graph in ``g``.

        Args:
            g: A DGL graph, possibly a batch of several graphs.
            in_feat: Node feature tensor; presumably (num_nodes, in_feats),
                TODO confirm against callers.

        Returns:
            The mean over each graph's nodes of the second layer's output.
        """
        h = F.relu(self.conv1(g, in_feat))
        h = self.conv2(g, h)
        # local_scope() discards the temporary 'h' node feature on exit, so the
        # readout does not leave extra data on the caller's graph.
        with g.local_scope():
            g.ndata['h'] = h
            x = dgl.mean_nodes(g, 'h')
            if self.verbose:
                print("--------- x.shape: {}, g: {}".format(x.shape, g))
        return x
class GAT(nn.Module):
    """Two-layer graph attention network for graph classification.

    Each GATConv layer's per-head outputs are averaged. A ReLU follows the
    first attention layer, and a final linear layer maps node embeddings to
    ``num_classes`` scores. Those scores are mean-pooled per graph.

    Args:
        in_feats: Size of each input node feature vector.
        h_feats: Hidden size produced by each attention head.
        num_classes: Size of the per-graph output vector.
        num_heads: Number of attention heads in both GATConv layers.
        verbose: If True, print intermediate shapes on every forward pass
            (debugging aid; off by default).
    """

    def __init__(self, in_feats, h_feats, num_classes, num_heads=3, verbose=False):
        super().__init__()
        self.conv1 = GATConv(in_feats, h_feats, num_heads=num_heads)
        self.conv2 = GATConv(h_feats, h_feats, num_heads=num_heads)
        # A plain linear layer; the attribute keeps its historical name
        # ``conv3`` so previously saved state_dict keys still load.
        self.conv3 = torch.nn.Linear(in_features=h_feats, out_features=num_classes)
        self.verbose = verbose

    def forward(self, g, in_feat):
        """Return one row of ``num_classes`` scores per graph in ``g``.

        Args:
            g: A DGL graph, possibly a batch of several graphs.
            in_feat: Node feature tensor; presumably (num_nodes, in_feats),
                TODO confirm against callers.
        """
        # Dim 1 of each GATConv output is the attention-head axis; average it away.
        h = self.conv1(g, in_feat).mean(dim=1)
        # The GATConv layers are built without an activation, so add one
        # between them (matching the ReLU the GCN model uses between layers).
        h = F.relu(h)
        h = self.conv2(g, h).mean(dim=1)
        h = self.conv3(h)
        # local_scope() discards the temporary 'h' node feature on exit, so the
        # readout does not leave extra data on the caller's graph.
        with g.local_scope():
            g.ndata['h'] = h
            x = dgl.mean_nodes(g, 'h')
            if self.verbose:
                print("--------- h.shape: {}, x.shape: {}, g.batch_size: {}".format(h.shape, x.shape, g.batch_size))
        return x
# Create the model with the given dimensions. Set USE_GAT to False to train
# the GCN baseline instead of the attention model.
USE_GAT = True
NUM_EPOCHS = 1
if USE_GAT:
    model = GAT(dataset.dim_nfeats, 16, dataset.gclasses)
else:
    model = GCN(dataset.dim_nfeats, 16, dataset.gclasses)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(NUM_EPOCHS):
    model.train()
    epoch_loss = 0.0
    num_batches = 0
    for batched_graph, labels in train_dataloader:
        pred = model(batched_graph, batched_graph.ndata['attr'].float())
        loss = F.cross_entropy(pred, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        num_batches += 1
    # max(..., 1) guards against an empty training split.
    print('Epoch {}: mean training loss {:.4f}'.format(
        epoch, epoch_loss / max(num_batches, 1)))

# Evaluate on the held-out test split without tracking gradients.
model.eval()
num_correct = 0
num_tests = 0
with torch.no_grad():
    for batched_graph, labels in test_dataloader:
        pred = model(batched_graph, batched_graph.ndata['attr'].float())
        num_correct += (pred.argmax(dim=1) == labels).sum().item()
        num_tests += len(labels)
print('Test accuracy:', num_correct / num_tests if num_tests else float('nan'))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.