@rusty1s
Created May 22, 2020 07:42

import torch
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader

from torch_geometric.utils import grid
from torch_geometric.nn import SplineConv

train_dataset = MNIST('/tmp/MNIST', train=True, transform=ToTensor())
test_dataset = MNIST('/tmp/MNIST', train=False, transform=ToTensor())
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True,
                          drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=64, drop_last=True)


def to_batch(edge_index, pos, batch_size):
    # Replicate the grid graph `batch_size` times, shifting the node indices
    # so that each copy forms its own connected component.
    edge_indices = [edge_index + pos.size(0) * i for i in range(batch_size)]
    edge_index = torch.cat(edge_indices, dim=1)
    pos = torch.cat([pos] * batch_size, dim=0)
    # Relative positions serve as edge features, normalized from [-1, 1]
    # to [0, 1] as required by the B-spline kernels.
    edge_attr = pos[edge_index[0]] - pos[edge_index[1]]
    return edge_index, (edge_attr + 1.) / 2.


# Pre-compute the batched grid graphs for the 28x28 input resolution and the
# 14x14 resolution after the first pooling step.
edge_index1, edge_attr1 = to_batch(*grid(28, 28), batch_size=64)
edge_index2, edge_attr2 = to_batch(*grid(14, 14), batch_size=64)


class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = SplineConv(1, 32, dim=2, kernel_size=3)
        self.conv2 = SplineConv(32, 64, dim=2, kernel_size=3)
        self.fc1 = torch.nn.Linear(3136, 512)
        self.fc2 = torch.nn.Linear(512, 10)

    def forward(self, x):
        # Flatten pixels into a node feature matrix of shape [N, 1].
        x = x.view(-1, 1)
        x = F.elu(self.conv1(x, edge_index1, edge_attr1))
        # Reshape to [B, C, H, W] for regular 2D max pooling, then back to a
        # node feature matrix for the next graph convolution.
        x = x.view(64, 28, 28, 32).permute(0, 3, 1, 2)
        x = F.max_pool2d(x, kernel_size=2)
        x = x.permute(0, 2, 3, 1).contiguous().view(-1, 32)
        x = F.elu(self.conv2(x, edge_index2, edge_attr2))
        x = x.view(64, 14, 14, 64).permute(0, 3, 1, 2)
        x = F.max_pool2d(x, kernel_size=2)
        x = x.contiguous().view(64, -1)  # 7 * 7 * 64 = 3136 features.
        x = F.elu(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=1)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
edge_index1, edge_attr1 = edge_index1.to(device), edge_attr1.to(device)
edge_index2, edge_attr2 = edge_index2.to(device), edge_attr2.to(device)


def train(epoch):
    model.train()
    for i, (x, y) in enumerate(train_loader):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = F.nll_loss(model(x), y)
        loss.backward()
        optimizer.step()
        print(i, len(train_loader), loss.item())


@torch.no_grad()
def test():
    model.eval()
    correct = 0
    for x, y in test_loader:
        x, y = x.to(device), y.to(device)
        pred = model(x).max(1)[1]
        correct += pred.eq(y).sum().item()
    # Note: drop_last=True skips the final incomplete batch, so this slightly
    # underestimates the accuracy over the full test set.
    return correct / len(test_dataset)


for epoch in range(1, 50):
    train(epoch)
    test_acc = test()
    print('Epoch: {:02d}, Test: {:.4f}'.format(epoch, test_acc))
@adhikarirsr

In the SplineCNN paper, you used Graclus-based pooling rather than 2D max pooling, right?

After 50 epochs, I got around 98.8% accuracy, which is still short of the reported 99.33%. I used a kernel_size of 5. What's the idea behind using contiguous (is it for the reason described in the answer here: https://discuss.pytorch.org/t/when-and-why-do-we-use-contiguous/47588)? And the reason you are using max_pool2d is that the graph here is a grid graph, right?

@rusty1s (Author) commented May 28, 2020

For the grid experiment, we use normal 2D max pooling since we cannot flatten the output features otherwise. To increase accuracy, adding two-hop neighbors to the grid helps. contiguous needs to be called since permute permutes the dimensions of the node features, and GNN operators generally expect inputs with contiguous memory layout.
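
For what it's worth, here is a minimal sketch of what adding two-hop neighbors to the grid could look like, using torch_geometric.transforms.TwoHop and reusing grid and to_batch from the gist above. The helper name grid_with_two_hop is my own, and this is only an assumption about the setup, not necessarily the exact configuration behind the reported 99.33%:

from torch_geometric.data import Data
from torch_geometric.transforms import TwoHop
from torch_geometric.utils import grid

def grid_with_two_hop(height, width):
    # Hypothetical helper: build the regular 8-neighborhood grid graph,
    # then additionally connect every node to its two-hop neighbors.
    edge_index, pos = grid(height, width)
    data = Data(edge_index=edge_index, pos=pos, num_nodes=pos.size(0))
    data = TwoHop()(data)
    return data.edge_index, data.pos

edge_index1, edge_attr1 = to_batch(*grid_with_two_hop(28, 28), batch_size=64)

Note that with two-hop edges the relative positions in to_batch can reach ±2, so the normalization there would need to change accordingly, e.g. to (edge_attr + 2.) / 4., to keep the pseudo-coordinates in [0, 1]. As for contiguous: permute only re-strides the tensor without moving any data, so the subsequent view would fail on the non-contiguous result, e.g.:

t = torch.randn(64, 32, 14, 14).permute(0, 2, 3, 1)
t.is_contiguous()  # False
t.view(-1, 32)     # RuntimeError; needs t.contiguous().view(-1, 32)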

@adhikarirsr

What are two-hop neighbors? Are they similar to strides in a normal CNN? How do I add them to the grid?
