@rusty1s
Created May 22, 2020 07:42
import torch
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
from torch_geometric.utils import grid
from torch_geometric.nn import SplineConv

train_dataset = MNIST('/tmp/MNIST', train=True, transform=ToTensor())
test_dataset = MNIST('/tmp/MNIST', train=False, transform=ToTensor())
# drop_last=True is required: the batched grid graph below is built for
# exactly 64 graphs, so incomplete batches cannot be processed.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True,
                          drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=64, drop_last=True)


def to_batch(edge_index, pos, batch_size):
    # Replicate the grid graph batch_size times, shifting the node indices
    # of each copy so all copies form one large, disconnected graph.
    edge_indices = [edge_index + pos.size(0) * i for i in range(batch_size)]
    edge_index = torch.cat(edge_indices, dim=1)
    pos = torch.cat([pos] * batch_size, dim=0)
    # Relative node positions serve as edge features, rescaled from
    # [-1, 1] to [0, 1] as expected by SplineConv.
    edge_attr = pos[edge_index[0]] - pos[edge_index[1]]
    return edge_index, (edge_attr + 1.) / 2.


# Fixed connectivity for the full-resolution (28x28) and pooled (14x14) grids.
edge_index1, edge_attr1 = to_batch(*grid(28, 28), batch_size=64)
edge_index2, edge_attr2 = to_batch(*grid(14, 14), batch_size=64)


class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = SplineConv(1, 32, dim=2, kernel_size=3)
        self.conv2 = SplineConv(32, 64, dim=2, kernel_size=3)
        self.fc1 = torch.nn.Linear(3136, 512)  # 64 channels * 7 * 7 positions
        self.fc2 = torch.nn.Linear(512, 10)

    def forward(self, x):
        # Flatten the image batch into a single node feature matrix.
        x = F.elu(self.conv1(x.view(-1, 1), edge_index1, edge_attr1))
        # Reshape to NCHW so regular 2D max pooling can be applied.
        x = x.view(64, 28, 28, 32).permute(0, 3, 1, 2)
        x = F.max_pool2d(x, kernel_size=2)
        # Back to a node feature matrix; contiguous() is needed after permute().
        x = x.permute(0, 2, 3, 1).contiguous().view(-1, 32)
        x = F.elu(self.conv2(x, edge_index2, edge_attr2))
        x = x.view(64, 14, 14, 64).permute(0, 3, 1, 2)
        x = F.max_pool2d(x, kernel_size=2)
        x = x.contiguous().view(64, -1)
        x = F.elu(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=1)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
edge_index1, edge_attr1 = edge_index1.to(device), edge_attr1.to(device)
edge_index2, edge_attr2 = edge_index2.to(device), edge_attr2.to(device)


def train(epoch):
    model.train()
    for i, (x, y) in enumerate(train_loader):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = F.nll_loss(model(x), y)
        loss.backward()
        optimizer.step()
        print(i, len(train_loader), loss.item())


def test():
    model.eval()
    correct = 0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            pred = model(x).max(1)[1]
            correct += pred.eq(y).sum().item()
    # Note: drop_last=True skips the final 10000 % 64 = 16 test samples,
    # while the denominator still counts the full dataset.
    return correct / len(test_dataset)


for epoch in range(1, 50):
    train(epoch)
    test_acc = test()
    print('Epoch: {:02d}, Test: {:.4f}'.format(epoch, test_acc))
@rusty1s (Author) commented May 28, 2020

For the grid experiment, we use regular 2D max pooling, since we cannot flatten the output features otherwise. To increase accuracy, it helps to add two-hop neighbors to the grid (see the sketch below). contiguous needs to be called because permute only rearranges the dimensions of the node features, while GNN operators generally expect inputs with a contiguous memory layout.
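A minimal sketch of one way to build such a two-hop grid (my own addition, not from the gist above): for a graph as small as the 28x28 grid, one can square the dense adjacency matrix and keep every pair of nodes reachable within two steps.

import torch
from torch_geometric.utils import grid, to_dense_adj

edge_index, pos = grid(28, 28)
adj = to_dense_adj(edge_index).squeeze(0)  # (784, 784) dense 0/1 adjacency
# Nodes reachable in one or two hops; dense matmul is cheap at this size.
adj2 = (adj + adj @ adj) > 0
edge_index = adj2.nonzero().t()            # back to COO edge_index format

# Relative positions now span [-2, 2] per axis, so the rescaling in
# to_batch would become (edge_attr + 2.) / 4., and a larger SplineConv
# kernel_size (e.g. 5) is presumably needed to resolve the wider
# neighborhood -- both are assumptions on my part, not settings
# confirmed above.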
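And a quick illustration of the contiguous point (again just a sketch): permute only changes strides, and view refuses to operate on a tensor whose memory layout no longer matches its shape.

import torch

t = torch.arange(6).view(2, 3).permute(1, 0)
print(t.is_contiguous())        # False: permute only swapped the strides
# t.view(-1) here raises a RuntimeError about incompatible size/stride.
print(t.contiguous().view(-1))  # works: copies into contiguous memory first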

@adhikarirsr

What are two-hop neighbors? Are they similar to strides in a normal CNN? How do I add them to the grid?
