Last active
March 11, 2020 20:32
-
-
Save lgray/5267e210d2a900b99e85b261ce6e2cfd to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import os.path as osp | |
import math | |
import numpy as np | |
import torch | |
import gc | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch_geometric.transforms as T | |
from torch.utils.checkpoint import checkpoint | |
from torch_cluster import knn_graph | |
from torch_geometric.nn import EdgeConv | |
from torch_geometric.utils import normalized_cut | |
from torch_geometric.utils import remove_self_loops | |
from torch_geometric.utils.undirected import to_undirected | |
from torch_geometric.nn import (graclus, max_pool, max_pool_x, | |
global_mean_pool, global_max_pool, | |
global_add_pool) | |
# Cartesian transform: edge attributes = relative node positions (cat=False).
# NOTE(review): re-assigned with identical arguments before dataset loading
# further down; this first assignment appears to be unused.
transform = T.Cartesian(cat=False)
def normalized_cut_2d(edge_index, pos):
    """Normalized-cut edge weights from Euclidean distances between endpoints.

    Args:
        edge_index: (2, E) tensor of source/target node indices.
        pos: (N, D) tensor of node coordinates (here: learned features).

    Returns:
        Per-edge normalized-cut weight tensor of shape (E,).
    """
    src, dst = edge_index
    dist = (pos[src] - pos[dst]).norm(p=2, dim=1)
    return normalized_cut(edge_index, dist, num_nodes=pos.size(0))
class DynamicReductionNetwork(nn.Module):
    """Dynamic graph network that clusters nearest-neighbour graphs in two steps.

    The latent space is trained to group useful features at each level of
    aggregation, which allows single quantities to be regressed from complex
    point clouds in a location- and orientation-invariant way. One encoding
    MLP abstracts away the input features; two EdgeConv + graclus-pooling
    stages coarsen the graph before a global max pool and an output MLP.

    Args:
        input_dim: number of features per input node.
        hidden_dim: width of the latent space.
        output_dim: number of regressed/classified outputs.
        k: neighbours per node for the kNN graphs built in ``forward``.
        aggr: EdgeConv aggregation ('add', 'mean', or 'max').
        norm: per-feature normalization applied to the input; stored as a
            learnable parameter.
    """

    def __init__(self, input_dim=5, hidden_dim=64, output_dim=1, k=16, aggr='add',
                 norm=torch.tensor([1./500., 1./500., 1./54., 1/25., 1./1000.])):
        super(DynamicReductionNetwork, self).__init__()
        # Clone the normalization tensor: the default argument is a single
        # tensor shared across calls, and nn.Parameter shares storage with the
        # tensor it wraps — without the clone, training one instance would
        # silently mutate the default norm of every later instance.
        self.datanorm = nn.Parameter(norm.clone())
        self.k = k

        # EdgeConv input is the concatenation [x_i, x_j - x_i] -> 2*hidden.
        start_width = 2 * hidden_dim
        middle_width = 3 * hidden_dim // 2

        # Encoder: raw features -> hidden_dim latent space.
        self.inputnet = nn.Sequential(
            nn.Linear(input_dim, hidden_dim // 2),
            nn.ELU(),
            nn.Linear(hidden_dim // 2, hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ELU()
        )
        convnn1 = nn.Sequential(
            nn.Linear(start_width, middle_width),
            nn.ELU(),
            nn.Linear(middle_width, hidden_dim),
            nn.ELU()
        )
        convnn2 = nn.Sequential(
            nn.Linear(start_width, middle_width),
            nn.ELU(),
            nn.Linear(middle_width, hidden_dim),
            nn.ELU()
        )
        self.edgeconv1 = EdgeConv(nn=convnn1, aggr=aggr)
        self.edgeconv2 = EdgeConv(nn=convnn2, aggr=aggr)

        # Decoder: pooled latent vector -> output_dim.
        self.output = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ELU(),
            nn.Linear(hidden_dim // 2, output_dim)
        )

    def forward(self, data):
        """Encode, twice (kNN-graph -> EdgeConv -> graclus pool), then decode.

        Args:
            data: a torch_geometric batch with ``x`` and ``batch`` attributes;
                ``x`` and ``edge_index``/``edge_attr`` are mutated in place.

        Returns:
            Tensor of shape (num_graphs, output_dim), with the trailing
            dimension squeezed away when output_dim == 1.
        """
        data.x = self.datanorm * data.x
        data.x = self.inputnet(data.x)

        # Stage 1: kNN graph in latent space, convolve, coarsen via graclus.
        data.edge_index = to_undirected(
            knn_graph(data.x, self.k, data.batch, loop=False,
                      flow=self.edgeconv1.flow))
        data.x = self.edgeconv1(data.x, data.edge_index)
        weight = normalized_cut_2d(data.edge_index, data.x)
        cluster = graclus(data.edge_index, weight, data.x.size(0))
        data.edge_attr = None  # max_pool would otherwise try to pool stale attrs
        data = max_pool(cluster, data)

        # Stage 2: rebuild the kNN graph on the coarsened nodes and repeat.
        data.edge_index = to_undirected(
            knn_graph(data.x, self.k, data.batch, loop=False,
                      flow=self.edgeconv2.flow))
        data.x = self.edgeconv2(data.x, data.edge_index)
        weight = normalized_cut_2d(data.edge_index, data.x)
        cluster = graclus(data.edge_index, weight, data.x.size(0))
        x, batch = max_pool_x(cluster, data.x, data.batch)

        # Global pooling makes the result permutation/size invariant.
        x = global_max_pool(x, batch)
        return self.output(x).squeeze(-1)
# --- Training script: MNIST superpixels with a DynamicReductionNetwork ------
import os.path as osp
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.datasets import MNISTSuperpixels
import torch_geometric.transforms as T
from torch_geometric.data import DataLoader
import warnings

# NOTE(review): this silences ALL warnings process-wide, not just the noisy
# torch_geometric ones — consider narrowing the filter.
warnings.simplefilter('ignore')

# Dataset is downloaded/cached under ../data/MNIST relative to the CWD.
path = osp.join('./', '..', 'data', 'MNIST')
# Edge attributes become relative 2-D positions of connected superpixels.
transform = T.Cartesian(cat=False)
train_dataset = MNISTSuperpixels(path, True, transform=transform)
test_dataset = MNISTSuperpixels(path, False, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

# `d` is used below to size the network's output layer.
d = train_dataset
print('features ->', d.num_features)
print('classes ->',d.num_classes)
class Net(nn.Module):
    """Classifier head: a DynamicReductionNetwork followed by log-softmax."""

    def __init__(self):
        super(Net, self).__init__()
        # 3 input features per node: intensity plus the 2-D superpixel
        # position concatenated in the training loop; the norm rescales the
        # position channels (range ~[0, 27]) to roughly unit scale.
        self.drn = DynamicReductionNetwork(
            input_dim=3,
            hidden_dim=256,
            k=16,
            output_dim=d.num_classes,
            aggr='add',
            norm=torch.tensor([1., 1./27., 1./27.]),
        )

    def forward(self, data):
        """Return per-graph log class probabilities, shape (num_graphs, C)."""
        return F.log_softmax(self.drn(data), dim=1)
# Use the GPU when available; AdamW decouples weight decay from the gradient.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
def train(epoch):
    """Run one training epoch over ``train_loader``.

    Applies a piecewise-constant learning-rate schedule keyed on the epoch
    number, masks out empty superpixels, and optimizes the NLL loss.

    Args:
        epoch: 1-based epoch index, used only for the LR schedule.

    Returns:
        Mean NLL loss over the epoch's batches (0.0 if the loader is empty).
    """
    model.train()

    # Step the learning rate down at fixed epochs.
    new_lr = {16: 0.0005, 32: 0.0001, 48: 0.00005}.get(epoch)
    if new_lr is not None:
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lr

    total_loss = 0.0
    n_batches = 0
    for data in train_loader:
        data = data.to(device)
        # Keep only non-empty superpixels. The mask is built from the
        # single-channel intensity BEFORE positions are concatenated, so it
        # indexes nodes of the original x.
        mask = (data.x > 0.).squeeze()
        data.x = torch.cat([data.x, data.pos], dim=-1)
        data.x = data.x[mask, :]
        data.pos = data.pos[mask, :]
        data.batch = data.batch[mask.squeeze()]

        optimizer.zero_grad()
        loss = F.nll_loss(model(data), data.y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        n_batches += 1

    return total_loss / max(n_batches, 1)
def test():
    """Evaluate the model on ``test_loader``.

    Applies the same empty-superpixel masking as training and counts
    argmax-correct predictions.

    Returns:
        Classification accuracy over the whole test set, in [0, 1].
    """
    model.eval()
    correct = 0
    # Inference only: disabling autograd avoids building graphs and saves
    # memory/compute during evaluation.
    with torch.no_grad():
        for data in test_loader:
            data = data.to(device)
            # Mask built from the raw intensity channel before positions are
            # concatenated (mirrors the training loop).
            mask = (data.x > 0.).squeeze()
            data.x = torch.cat([data.x, data.pos], dim=-1)
            data.x = data.x[mask, :]
            data.pos = data.pos[mask, :]
            data.batch = data.batch[mask.squeeze()]

            pred = model(data).max(1)[1]
            correct += pred.eq(data.y).sum().item()
    return correct / len(test_dataset)
# Train for 64 epochs, reporting test accuracy after each one.
for epoch in range(1, 65):
    train(epoch)
    test_acc = test()
    print('Epoch: {:02d}, Test: {:.4f}'.format(epoch, test_acc))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment