Skip to content

Instantly share code, notes, and snippets.

@mkocabas
Created January 23, 2019 11:12
Show Gist options
Save mkocabas/a4bc01b7ee4076eb1d5af095caa70c1f to your computer and use it in GitHub Desktop.
import dgl
import torch
import torch.nn as nn
def build_pose_graph():
    """Build the fixed 16-node pose graph with edges in both directions.

    An empty ``dgl.DGLGraph`` is created, 16 nodes are added, and each of
    the 15 node pairs listed below is inserted twice: once as
    ``(src, dst)`` and once as ``(dst, src)``, for 30 directed edges.

    Returns:
        The populated ``dgl.DGLGraph``.
    """
    graph = dgl.DGLGraph()
    graph.add_nodes(16)

    pairs = [
        (0, 1), (1, 2), (2, 6), (6, 3), (3, 4),
        (4, 5), (6, 7), (7, 8), (8, 9), (8, 12),
        (12, 11), (11, 10), (8, 13), (13, 14), (14, 15),
    ]
    sources = tuple(a for a, _ in pairs)
    targets = tuple(b for _, b in pairs)

    # Forward edges are added first, then the reversed copies, so edge IDs
    # keep the same ordering: pairs in list order, then their mirrors.
    graph.add_edges(sources, targets)
    graph.add_edges(targets, sources)
    return graph
def gcn_message(edges):
    """Message function: pass along the source node's ``'h'`` feature.

    Args:
        edges: Edge batch object whose ``src`` attribute maps feature
            names to the features of each edge's source node.

    Returns:
        A dict with the single key ``'msg'`` holding ``edges.src['h']``
        unchanged.
    """
    source_state = edges.src['h']
    return dict(msg=source_state)
def gcn_reduce(nodes):
    """Reduce function: sum the received ``'msg'`` entries into ``'h'``.

    Args:
        nodes: Node batch object whose ``mailbox['msg']`` is a tensor of
            received messages. It is summed over dimension 1 -- presumably
            laid out as (nodes, messages, features); confirm against the
            DGL version in use.

    Returns:
        A dict with the single key ``'h'`` holding the summed tensor.
    """
    inbox = nodes.mailbox['msg']
    return {'h': inbox.sum(dim=1)}
class GCNLayer(nn.Module):
    """Graph-convolution layer with a separate linear map for every node.

    Each node first replaces its feature with the sum of its in-neighbours'
    features (via ``gcn_message`` / ``gcn_reduce``), then the aggregated
    feature of node ``i`` is passed through that node's own ``nn.Linear``.

    Args:
        in_feats: Size of each node's input feature vector.
        out_feats: Size of each node's output feature vector.
        num_nodes: Number of per-node linear maps to create; the graph
            passed to ``forward`` must not have more nodes than this.
    """

    def __init__(self, in_feats, out_feats, num_nodes):
        super(GCNLayer, self).__init__()
        # One independent nn.Linear per node: weights are NOT shared
        # across nodes.
        self.linears = nn.ModuleList(
            [nn.Linear(in_feats, out_feats) for _ in range(num_nodes)]
        )

    def forward(self, g, inputs):
        """Aggregate neighbour features and apply the per-node linear maps.

        Args:
            g: DGL graph to run message passing on. Its ``ndata['h']``
                field is overwritten as a side effect.
            inputs: Node feature tensor, indexed by node along dimension 0.

        Returns:
            Tensor stacking the per-node linear outputs along dimension 0.

        Raises:
            IndexError: If the graph has more nodes than ``num_nodes``.
        """
        g.ndata['h'] = inputs
        # update_all sends along every edge and reduces on every node, i.e.
        # the same as send(g.edges(), ...) followed by recv(g.nodes(), ...).
        # The standalone send/recv pair was removed from later DGL releases,
        # while update_all is available in both old and current versions.
        g.update_all(gcn_message, gcn_reduce)
        h = g.ndata['h']
        # Row i of h is node i's aggregated feature; it goes through linear i.
        return torch.stack([self.linears[i](x) for i, x in enumerate(h)])
class GCN(nn.Module):
    """Two stacked ``GCNLayer`` modules with a ReLU between them.

    Args:
        in_feats: Size of each node's input feature vector.
        hidden_size: Feature size produced by the first layer.
        num_classes: Feature size produced by the second layer.
        num_nodes: Node count forwarded to both layers.
    """

    def __init__(self, in_feats, hidden_size, num_classes, num_nodes):
        super(GCN, self).__init__()
        self.gcn1 = GCNLayer(
            in_feats=in_feats, out_feats=hidden_size, num_nodes=num_nodes
        )
        self.gcn2 = GCNLayer(
            in_feats=hidden_size, out_feats=num_classes, num_nodes=num_nodes
        )

    def forward(self, g, inputs):
        """Run both layers on graph ``g`` and return the second layer's output.

        Args:
            g: Graph handed unchanged to both layers.
            inputs: Node features fed to the first layer.

        Returns:
            The output of the second layer (no activation applied to it).
        """
        hidden = torch.relu(self.gcn1(g, inputs))
        return self.gcn2(g, hidden)
def main():
    """Build the pose graph, run an untrained GCN on it, and print shapes.

    Prints the node and edge counts, the node IDs, and the shape of the
    network output for random 2-dimensional node features.
    """
    G = build_pose_graph()
    num_nodes = G.number_of_nodes()
    print('We have %d nodes.' % num_nodes)
    print('We have %d edges.' % G.number_of_edges())

    in_feats = 2
    # Size the features from the graph instead of repeating the node count.
    G.ndata['feat'] = torch.randn(num_nodes, in_feats)
    print(G.nodes())

    net = GCN(in_feats, 16, 3, num_nodes)
    # Feed the features stored on the graph; previously they were assigned
    # but never read, and a second unrelated random tensor was used instead.
    output = net(G, G.ndata['feat'])
    print(output.shape)
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment