Skip to content

Instantly share code, notes, and snippets.

@lgray
Created March 10, 2020 20:12
Show Gist options
  • Save lgray/b768dc102eee102a0046aece8a23fa46 to your computer and use it in GitHub Desktop.
Save lgray/b768dc102eee102a0046aece8a23fa46 to your computer and use it in GitHub Desktop.
import torch
import torch.nn.functional as F
import torch_geometric.transforms as T
import torch.nn as nn
from torch_geometric.nn import EdgeConv, DynamicEdgeConv
#let's try a basic implementation of really simple message passing
from torch_scatter import scatter_add
class NodeNetwork(nn.Module):
    """Node update step based on summed, edge-weighted neighbour features.

    For each node two aggregates are built with ``scatter_add``: one summing
    the weighted features arriving over edges where the node is the second
    endpoint, one where it is the first endpoint.  Both aggregates are
    concatenated with the node's own features and passed through a small
    two-layer MLP.

    Args:
        input_dim: number of features per node; the first linear layer takes
            ``3 * input_dim`` inputs because of the three-way concatenation.
        output_dim: width of both linear layers' outputs.
        hidden_activation: activation class instantiated after each linear
            layer.
    """

    def __init__(self, input_dim, output_dim, hidden_activation=nn.Tanh):
        super().__init__()
        layers = [
            nn.Linear(3 * input_dim, output_dim),
            hidden_activation(),
            nn.Linear(output_dim, output_dim),
            hidden_activation(),
        ]
        self.nodec = nn.Sequential(*layers)

    def forward(self, x, edge_index, edge_attr):
        """Return updated node features.

        ``edge_index`` is read column-wise (``[:, 0]`` and ``[:, 1]``) to get
        the two endpoints of every edge.  ``edge_attr`` is multiplied directly
        with the gathered node features, so it must broadcast against them
        -- NOTE(review): expected shape of ``edge_attr`` is not visible here,
        confirm with the caller.
        """
        first_end = edge_index[:, 0]
        second_end = edge_index[:, 1]

        # Weighted messages travelling in each direction along the edges.
        toward_second = edge_attr * x[first_end]
        toward_first = edge_attr * x[second_end]

        # Sum the messages per receiving node into zero-initialised buffers
        # that have the same shape as x.
        into_second = scatter_add(toward_second, second_end, dim=0,
                                  out=x.new_zeros(x.shape))
        into_first = scatter_add(toward_first, first_end, dim=0,
                                 out=x.new_zeros(x.shape))

        stacked = torch.cat([into_second, into_first, x], dim=-1)
        return self.nodec(stacked)
# Build a NodeNetwork for 5 input features and 64 output features and show
# its eager-mode structure.
node_test = NodeNetwork(input_dim=5, output_dim=64)
print(node_test)

# Compile the same module with TorchScript and show the scripted module.
node_test_scripted = torch.jit.script(node_test)
print(node_test_scripted)
#now let's try a full network using the actual implementations
class TestNet(nn.Module):
    """Edge-scoring network: input embedding, repeated EdgeConv, edge MLP.

    Node features are first embedded by ``inputnet``.  The embedding is then
    refined ``n_iters`` times by a single shared ``EdgeConv`` layer; before
    every step (and before the final edge scoring) the current hidden
    features are concatenated with the raw input features.  Finally, the
    features of both endpoints of each edge are concatenated and passed
    through ``edgenetwork``.

    Args:
        input_dim: number of raw features per node.
        hidden_dim: width of the hidden node representation.
        output_dim: number of outputs per edge.
        n_iters: how many times the EdgeConv layer is applied.
        aggr: aggregation mode forwarded to ``EdgeConv``.
    """

    def __init__(self, input_dim=3, hidden_dim=8, output_dim=1, n_iters=1, aggr='add'):
        super(TestNet, self).__init__()
        # MLP handed to EdgeConv; its input width is twice the width of the
        # concatenated [hidden, raw] node features.
        convnn = nn.Sequential(
            nn.Linear(2 * (hidden_dim + input_dim), hidden_dim),
            nn.Sigmoid(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Sigmoid(),
        )
        self.n_iters = n_iters
        self.inputnet = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.Tanh(),
        )
        self.edgenetwork = nn.Sequential(
            nn.Linear(2 * (hidden_dim + input_dim), output_dim),
            nn.Sigmoid(),
        )
        self.nodenetwork = EdgeConv(nn=convnn, aggr=aggr)

    def forward(self, data):
        """Return one score vector per edge of ``data``.

        ``data`` must expose ``x`` (node features) and ``edge_index`` (a pair
        of index tensors that unpacks into the two endpoints of every edge).
        The input object is not modified, so the same ``data`` can be passed
        through the network repeatedly.
        """
        X = data.x
        edge_index = data.edge_index
        H = self.inputnet(X)
        # Keep the working features in a local instead of writing them back
        # to data.x: overwriting the caller's tensor would change its width
        # and break any later forward pass on the same object.
        features = torch.cat([H, X], dim=-1)
        for _ in range(self.n_iters):
            H = self.nodenetwork(features, edge_index)
            features = torch.cat([H, X], dim=-1)
        row, col = edge_index
        edge_inputs = torch.cat([features[row], features[col]], dim=-1)
        # squeeze(-1) drops the trailing axis only when output_dim == 1.
        return self.edgenetwork(edge_inputs).squeeze(-1)
# Prefer the GPU, but fall back to the CPU so the script still runs on
# machines without a usable CUDA device.
_device = 'cuda' if torch.cuda.is_available() else 'cpu'
test = TestNet(input_dim=5, hidden_dim=64, output_dim=4, n_iters=6, aggr='add').to(_device)
print(test)

# Compile the full network with TorchScript and show the scripted module.
test_scripted = torch.jit.script(test)
print(test_scripted)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment