Skip to content

Instantly share code, notes, and snippets.

@rusty1s
Created May 22, 2020 07:42
Show Gist options
  • Save rusty1s/866ca8b830f22aa34923a0f8164f6e64 to your computer and use it in GitHub Desktop.
Save rusty1s/866ca8b830f22aa34923a0f8164f6e64 to your computer and use it in GitHub Desktop.
import torch
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
from torch_geometric.utils import grid
from torch_geometric.nn import SplineConv
# MNIST train/test splits, with images converted to tensors by ToTensor().
# download=True fetches the files into /tmp/MNIST on the first run; without it
# the dataset constructor fails when the data is not already present there.
train_dataset = MNIST('/tmp/MNIST', train=True, download=True,
                      transform=ToTensor())
test_dataset = MNIST('/tmp/MNIST', train=False, download=True,
                     transform=ToTensor())
# drop_last=True keeps every batch at exactly 64 images, which the batched
# grid graphs (built for batch_size=64) and Net.forward's reshapes rely on.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True,
                          drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=64, drop_last=True)
def to_batch(edge_index, pos, batch_size):
    """Tile a single graph ``batch_size`` times into one disjoint batched graph.

    Args:
        edge_index: (2, E) integer tensor of (source, target) node ids for one
            graph.
        pos: (N, D) tensor of node positions for that graph.
        batch_size: number of copies to place side by side.

    Returns:
        A pair ``(edge_index, edge_attr)``.

        - ``edge_index`` has shape (2, batch_size * E); copy ``b`` has its node
          ids shifted by ``b * N``.
        - ``edge_attr`` has shape (batch_size * E, D) and holds
          ``(pos[src] - pos[dst] + 1) / 2`` per edge. This maps per-axis
          offsets in [-1, 1] (as between neighbours on a unit-spaced grid)
          into [0, 1].
    """
    num_nodes = pos.size(0)
    # Copy ``b`` owns node ids [b * num_nodes, (b + 1) * num_nodes).
    shifted = [edge_index + b * num_nodes for b in range(batch_size)]
    batched_index = torch.cat(shifted, dim=1)
    # Every copy shares the same coordinates, so the offsets of one copy can be
    # computed once and tiled; this equals gathering from tiled positions.
    offsets = pos[edge_index[0]] - pos[edge_index[1]]
    offsets = torch.cat([offsets] * batch_size, dim=0)
    return batched_index, (offsets + 1.) / 2.
# Batched grid graphs (64 copies each): the 28x28 graph for the input image
# and the 14x14 graph for the feature map left after the first 2x2 max-pool.
(edge_index1, edge_attr1), (edge_index2, edge_attr2) = [
    to_batch(*grid(side, side), batch_size=64) for side in (28, 14)
]
class Net(torch.nn.Module):
    """SplineConv MNIST classifier that treats each image as a pixel-grid graph.

    Two SplineConv layers run on the module-level batched grid graphs:
    ``edge_index1``/``edge_attr1`` for the 28x28 input and
    ``edge_index2``/``edge_attr2`` for the pooled 14x14 map. Each is followed
    by ELU and a 2x2 max-pool over the image layout. Two linear layers then
    map the flattened 64x7x7 features to 10 log-probabilities.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = SplineConv(1, 32, dim=2, kernel_size=3)
        self.conv2 = SplineConv(32, 64, dim=2, kernel_size=3)
        # 64 channels * 7 * 7 pixels after the second pool.
        self.fc1 = torch.nn.Linear(3136, 512)
        self.fc2 = torch.nn.Linear(512, 10)

    def forward(self, x):
        """Return log-probabilities of shape (64, 10).

        ``x`` must hold exactly 64 images of 28x28 pixels, e.g. a
        (64, 1, 28, 28) batch. Values are flattened in image-major, row-major
        order, which is the node order the batched grid graphs use.

        Raises:
            ValueError: if ``x`` does not contain 64 * 28 * 28 values. The edge
                tensors were batched for exactly 64 images, so any other size
                cannot line up with them.
        """
        expected = 64 * 28 * 28
        if x.numel() != expected:
            raise ValueError(
                'Net.forward expects 64 images of 28x28 pixels ({} values), '
                'got {}'.format(expected, x.numel()))
        # One graph node per pixel with a single feature (its intensity).
        x = F.elu(self.conv1(x.view(-1, 1), edge_index1, edge_attr1))
        # Node features viewed as an NHWC image, then NCHW for 2D pooling.
        x = x.view(64, 28, 28, 32).permute(0, 3, 1, 2)
        x = F.max_pool2d(x, kernel_size=2)
        # Back to one row per node of the 14x14 grid graph.
        x = x.permute(0, 2, 3, 1).contiguous().view(-1, 32)
        x = F.elu(self.conv2(x, edge_index2, edge_attr2))
        x = x.view(64, 14, 14, 64).permute(0, 3, 1, 2)
        x = F.max_pool2d(x, kernel_size=2)
        # (64, 64, 7, 7) -> (64, 3136) to match fc1.
        x = x.contiguous().view(64, -1)
        x = F.elu(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=1)
# Prefer the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# Net.forward reads these module-level edge tensors directly, so they must
# live on the same device as the model and its inputs.
edge_index1 = edge_index1.to(device)
edge_attr1 = edge_attr1.to(device)
edge_index2 = edge_index2.to(device)
edge_attr2 = edge_attr2.to(device)
def train(epoch):
    """Run one optimisation pass over ``train_loader``.

    After every optimizer step this prints the batch index, the number of
    batches and the batch's NLL loss. ``epoch`` is accepted from the caller
    but not used.
    """
    model.train()
    num_batches = len(train_loader)
    for step, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        batch_loss = F.nll_loss(model(images), labels)
        batch_loss.backward()
        optimizer.step()
        print(step, num_batches, batch_loss.item())
def test():
    """Return the model's accuracy over the samples yielded by ``test_loader``.

    Runs in eval mode with autograd disabled. Accuracy is computed over the
    samples actually evaluated. With ``drop_last=True`` the trailing partial
    batch is skipped (9984 of 10000 MNIST test images at batch size 64), so
    dividing by ``len(test_dataset)`` would understate accuracy.

    Returns:
        Fraction of correctly classified samples, or 0.0 if the loader yields
        no samples.
    """
    model.eval()
    correct = 0
    total = 0
    # Inference only: avoid building autograd graphs for every batch.
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            pred = model(x).max(1)[1]
            correct += pred.eq(y).sum().item()
            total += y.size(0)
    return correct / total if total else 0.0
# Epochs are numbered 1..49 (the range end is exclusive).
for epoch in range(1, 50):
    # One pass over the training set, then an evaluation on the test set.
    train(epoch)
    test_acc = test()
    print('Epoch: {:02d}, Test: {:.4f}'.format(epoch, test_acc))
@adhikarirsr
Copy link

What are two-hop neighbors? Are they similar to strides in a normal CNN? How do I add them to the grid?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment