import os
import os.path as osp
import math
import numpy as np
import torch
import gc
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch.utils.checkpoint import checkpoint
from torch_cluster import knn_graph
from torch_geometric.nn import EdgeConv
from torch_geometric.utils import normalized_cut
from torch_geometric.utils import remove_self_loops
from torch_geometric.utils.undirected import to_undirected
from torch_geometric.nn import (graclus, max_pool, max_pool_x,
                                global_mean_pool, global_max_pool,
                                global_add_pool)

transform = T.Cartesian(cat=False)
def normalized_cut_2d(edge_index, pos):
    # Edge weight = Euclidean distance between the two endpoints,
    # converted into normalized-cut weights for graclus pooling.
    row, col = edge_index
    edge_attr = torch.norm(pos[row] - pos[col], p=2, dim=1)
    return normalized_cut(edge_index, edge_attr, num_nodes=pos.size(0))
class DynamicReductionNetwork(nn.Module):
    # This model clusters nearest-neighbour graphs in two steps.
    # The latent space is trained to group useful features at each level
    # of aggregation.
    # This allows single quantities to be regressed from complex point clouds
    # in a location- and orientation-invariant way.
    # One encoding layer is used to abstract away the input features.
    def __init__(self, input_dim=5, hidden_dim=64, output_dim=1, k=16, aggr='add',
                 norm=torch.tensor([1./500., 1./500., 1./54., 1./25., 1./1000.])):
        super(DynamicReductionNetwork, self).__init__()
        self.datanorm = nn.Parameter(norm)
        self.k = k
        start_width = 2 * hidden_dim
        middle_width = 3 * hidden_dim // 2
        self.inputnet = nn.Sequential(
            nn.Linear(input_dim, hidden_dim//2),
            nn.ELU(),
            nn.Linear(hidden_dim//2, hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ELU()
        )
        convnn1 = nn.Sequential(nn.Linear(start_width, middle_width),
                                nn.ELU(),
                                nn.Linear(middle_width, hidden_dim),
                                nn.ELU()
                                )
        convnn2 = nn.Sequential(nn.Linear(start_width, middle_width),
                                nn.ELU(),
                                nn.Linear(middle_width, hidden_dim),
                                nn.ELU()
                                )
        self.edgeconv1 = EdgeConv(nn=convnn1, aggr=aggr)
        self.edgeconv2 = EdgeConv(nn=convnn2, aggr=aggr)
        self.output = nn.Sequential(nn.Linear(hidden_dim, hidden_dim),
                                    nn.ELU(),
                                    nn.Linear(hidden_dim, hidden_dim//2),
                                    nn.ELU(),
                                    nn.Linear(hidden_dim//2, output_dim))
    def forward(self, data):
        data.x = self.datanorm * data.x
        data.x = self.inputnet(data.x)
        # First kNN graph in the learned latent space, then EdgeConv + graclus pooling.
        data.edge_index = to_undirected(knn_graph(data.x, self.k, data.batch, loop=False, flow=self.edgeconv1.flow))
        data.x = self.edgeconv1(data.x, data.edge_index)
        weight = normalized_cut_2d(data.edge_index, data.x)
        cluster = graclus(data.edge_index, weight, data.x.size(0))
        data.edge_attr = None
        data = max_pool(cluster, data)
        # Second kNN graph on the pooled nodes, then EdgeConv + graclus pooling.
        data.edge_index = to_undirected(knn_graph(data.x, self.k, data.batch, loop=False, flow=self.edgeconv2.flow))
        data.x = self.edgeconv2(data.x, data.edge_index)
        weight = normalized_cut_2d(data.edge_index, data.x)
        cluster = graclus(data.edge_index, weight, data.x.size(0))
        x, batch = max_pool_x(cluster, data.x, data.batch)
        # Global pooling collapses each graph to a single vector before the output MLP.
        x = global_max_pool(x, batch)
        return self.output(x).squeeze(-1)
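
# --- Not part of the original gist: a minimal smoke test of the model above. ---
# Build a random 5-feature point cloud, wrap it in a Batch, and run one forward
# pass; the sizes and values here are illustrative assumptions, not real data.
from torch_geometric.data import Data, Batch

_drn = DynamicReductionNetwork(input_dim=5, hidden_dim=64, output_dim=1, k=16)
_example = Batch.from_data_list([Data(x=torch.rand(100, 5))])  # one graph, 100 points
with torch.no_grad():
    _out = _drn(_example)
print('smoke-test output (one regressed value per graph):', _out)
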
import os.path as osp
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.datasets import MNISTSuperpixels
import torch_geometric.transforms as T
from torch_geometric.data import DataLoader
import warnings
warnings.simplefilter('ignore')
path = osp.join('./', '..', 'data', 'MNIST')
transform = T.Cartesian(cat=False)
train_dataset = MNISTSuperpixels(path, True, transform=transform)
test_dataset = MNISTSuperpixels(path, False, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)
d = train_dataset
print('features ->', d.num_features)
print('classes ->', d.num_classes)
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.drn = DynamicReductionNetwork(input_dim=3, hidden_dim=256,
                                           k=16,
                                           output_dim=d.num_classes, aggr='add',
                                           norm=torch.tensor([1., 1./27., 1./27.]))

    def forward(self, data):
        logits = self.drn(data)
        return F.log_softmax(logits, dim=1)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
def train(epoch):
    model.train()
    # Simple step-wise learning-rate schedule.
    if epoch == 16:
        for param_group in optimizer.param_groups:
            param_group['lr'] = 0.0005
    if epoch == 32:
        for param_group in optimizer.param_groups:
            param_group['lr'] = 0.0001
    if epoch == 48:
        for param_group in optimizer.param_groups:
            param_group['lr'] = 0.00005
    for data in train_loader:
        data = data.to(device)
        # Keep only superpixels with non-zero intensity and append their positions as features.
        mask = (data.x > 0.).squeeze()
        data.x = torch.cat([data.x, data.pos], dim=-1)
        data.x = data.x[mask, :]
        #print(data.x)
        data.pos = data.pos[mask, :]
        data.batch = data.batch[mask]
        optimizer.zero_grad()
        result = model(data)
        loss = F.nll_loss(result, data.y)
        loss.backward()
        #print(torch.unique(torch.argmax(result, dim=-1)))
        #print(torch.unique(data.y))
        optimizer.step()
def test():
    model.eval()
    correct = 0
    for data in test_loader:
        data = data.to(device)
        # Same masking as in train(): drop empty superpixels.
        mask = (data.x > 0.).squeeze()
        data.x = torch.cat([data.x, data.pos], dim=-1)
        data.x = data.x[mask, :]
        #print(data.x)
        data.pos = data.pos[mask, :]
        data.batch = data.batch[mask]
        pred = model(data).max(1)[1]
        correct += pred.eq(data.y).sum().item()
    return correct / len(test_dataset)
for epoch in range(1, 65):
    train(epoch)
    test_acc = test()
    print('Epoch: {:02d}, Test: {:.4f}'.format(epoch, test_acc))
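
# Not in the original gist: persist the trained weights once training finishes.
# The filename is an arbitrary placeholder.
torch.save(model.state_dict(), 'drn_mnist_superpixels.pt')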