# Source: GitHub Gist mkocabas/8b45965efebfd0f75e21736c93bf158c
# (created November 7, 2019 12:44).
import dgl
import torch
import torch.nn as nn
import time
def build_pose_graph():
    """Build the skeleton graph consumed by the pose GCN.

    The skeleton has 15 bones over joint ids 0-15. Each bone is inserted
    in both directions, and then one self-loop is added per node. Edge ids
    are therefore ordered: forward bones, reversed bones, self loops.

    Returns:
        A ``dgl.DGLGraph`` holding those edges.
    """
    bones = (
        (0, 1), (1, 2), (2, 6), (6, 3), (3, 4),
        (4, 5), (6, 7), (7, 8), (8, 9), (8, 12),
        (12, 11), (11, 10), (8, 13), (13, 14), (14, 15),
    )
    heads = tuple(parent for parent, _ in bones)
    tails = tuple(child for _, child in bones)

    graph = dgl.DGLGraph()
    graph.add_edges(heads, tails)   # parent -> child
    graph.add_edges(tails, heads)   # child -> parent (makes the graph symmetric)

    # Self loops let every node also aggregate its own features.
    every_node = graph.nodes().tolist()
    graph.add_edges(every_node, every_node)
    return graph
class ActivationLayer(nn.Module):
    """BatchNorm -> ReLU -> Dropout over a rank-3 feature tensor.

    The input ``x`` of shape ``(d0, d1, d2)`` is flattened to
    ``(d0 * d1, d2)`` and transposed. The last dimension ``d2`` thus
    becomes the BatchNorm batch axis, and each of the ``d0 * d1`` rows
    becomes one BatchNorm channel. The result is returned in the original
    shape.

    Args:
        size: number of BatchNorm channels; must equal ``d0 * d1``.
        p_dropout: dropout probability.
    """

    def __init__(self, size, p_dropout=0.5):
        super().__init__()
        # Registration order (relu, dropout, batch_norm) is kept so that
        # parameter/state_dict ordering is unchanged.
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(p_dropout)
        self.batch_norm = nn.BatchNorm1d(size)

    def forward(self, x):
        shape_in = x.shape
        # (d0, d1, d2) -> (d0*d1, d2) -> (d2, d0*d1): BatchNorm1d expects (N, C).
        channels_last = x.reshape([-1, shape_in[2]]).t()
        activated = self.dropout(self.relu(self.batch_norm(channels_last)))
        return activated.t().reshape(shape_in)
class GCNLayer(nn.Module):
    """Graph convolution that learns one weight matrix per edge.

    Each edge ``e = (u -> v)`` owns an ``(out_feats, in_feats)`` matrix
    ``W_e``. The message sent along it is ``W_e @ h_u``, and every node
    sums the messages it receives.

    Args:
        in_feats: size of the per-node input feature dimension.
        out_feats: size of the per-node output feature dimension.
        num_edges: number of edges in the graph this layer runs on; one
            weight matrix is allocated per edge.
        activation: optional module/callable applied to the aggregated
            node features.
    """

    def __init__(self, in_feats, out_feats, num_edges, activation=None):
        super(GCNLayer, self).__init__()
        # Glorot/Xavier-uniform init for each (out_feats x in_feats) matrix.
        # An all-zero start makes every layer emit zeros, so no weight gets a
        # non-zero gradient and the network can never leave that point.
        bound = (6.0 / (in_feats + out_feats)) ** 0.5
        self.linears = nn.Parameter(
            torch.empty(num_edges, out_feats, in_feats).uniform_(-bound, bound))
        self.activation = activation

    def gcn_message(self, edges):
        """Return the per-edge messages ``W_e @ h_src`` under key ``'msg'``.

        For the ``bmm`` to apply, ``edges.src['h']`` must have shape
        ``(num_edges, in_feats, B)``. The messages then have shape
        ``(num_edges, out_feats, B)``.
        """
        # NOTE(review): pairing messages with self.linears by position assumes
        # this UDF receives all edges at once, in edge-ID order (as with
        # update_all) — confirm for the DGL version in use.
        hs = torch.bmm(self.linears, edges.src['h'])
        return {'msg': hs}

    def forward(self, g, inputs):
        """Apply the layer to node features ``inputs`` on graph ``g``.

        ``inputs`` is ``(num_nodes, in_feats, B)``. The result is
        ``(num_nodes, out_feats, B)``, passed through ``activation`` when
        one is set. The temporary ``'h'`` node field is popped from ``g``
        before returning.
        """
        g.ndata['h'] = inputs
        g.update_all(self.gcn_message, dgl.function.sum(msg='msg', out='h'))
        h = g.ndata.pop('h')
        if self.activation is not None:
            h = self.activation(h)
        return h
class DoubleGcnLayer(nn.Module):
    """Two stacked hidden GCNLayers, optionally wrapped in a residual connection.

    Each GCNLayer maps ``hidden_size -> hidden_size`` and is followed by an
    ActivationLayer over ``hidden_size * num_nodes`` channels.

    Args:
        hidden_size: per-node feature size (input and output).
        num_edges: number of graph edges (one weight matrix per edge).
        num_nodes: number of graph nodes.
        skip: when true, return ``x + block(x)``; otherwise ``block(x)``.
    """

    def __init__(self, hidden_size, num_edges, num_nodes, skip):
        super(DoubleGcnLayer, self).__init__()
        self.skip = skip
        self.gcn1 = GCNLayer(hidden_size, hidden_size,
                             num_edges, ActivationLayer(hidden_size*num_nodes))
        self.gcn2 = GCNLayer(hidden_size, hidden_size,
                             num_edges, ActivationLayer(hidden_size*num_nodes))

    def forward(self, g, x):
        y = self.gcn1(g, x)
        y = self.gcn2(g, y)
        # Residual add only when requested. Previously ``out = y`` ran
        # unconditionally after the skip branch, so ``skip`` had no effect.
        # Shapes match because both layers are hidden_size -> hidden_size.
        return x + y if self.skip else y
# GCN over a fixed graph (by default the skeleton from build_pose_graph()):
# an input GCNLayer, then num_hidden_layers DoubleGcnLayer blocks, then an
# output GCNLayer, all run on the graph stored as self.g.
# NOTE(review): this class was recovered from a scraped gist page. Its
# indentation was lost and some lines appear to be missing (see notes below).
class GCN(nn.Module):
def __init__(self,
# NOTE(review): build_pose_graph() in the default below is evaluated once, at
# class-definition time, so every GCN built without an explicit g shares the
# same graph object.
g: dgl.DGLGraph = build_pose_graph(),
# NOTE(review): the rest of the parameter list and the closing "):" seem to be
# missing. in_feats, hidden_size, num_classes, num_hidden_layers and skip are
# used below but not declared here, and their order/defaults cannot be
# recovered from this file. main() calls GCN(2, 130, 2, 3, G) positionally.
super(GCN, self).__init__()
self.in_feats = in_feats
self.num_classes = num_classes
# Input projection: in_feats -> hidden_size per node, followed by BatchNorm/ReLU/Dropout.
self.gcn_in = GCNLayer(in_feats, hidden_size,
g.number_of_edges(), ActivationLayer(hidden_size*g.number_of_nodes()))
self.hidden_layers = []
for l in range(num_hidden_layers):
# NOTE(review): the start of this statement appears to be missing. The
# arguments match DoubleGcnLayer(hidden_size, num_edges, num_nodes, skip), so
# presumably: self.hidden_layers.append(DoubleGcnLayer(
hidden_size, g.number_of_edges(), g.number_of_nodes(), skip=skip))
# Wrapping in ModuleList registers the blocks' parameters with this module.
self.hidden_layers = nn.ModuleList(self.hidden_layers)
# NOTE(review): the continuation of this call (presumably g.number_of_edges()
# and an optional activation, then ")") appears to be missing.
self.gcn_out = GCNLayer(hidden_size, num_classes,
self.g = g
# inputs: (B, num_nodes * in_feats), with columns laid out node-major
# (node0 feats, node1 feats, ...) as implied by the transpose + reshape
# below. Returns (B, num_nodes * num_classes) in the same node-major layout.
def forward(self, inputs):
# start = time.time()
# (B, N*F) -> (N*F, B) -> (N, F, B): nodes first, batch last, as GCNLayer expects.
inputs = torch.t(inputs)
inputs = inputs.reshape([self.g.number_of_nodes(), self.in_feats, -1])
h = self.gcn_in(self.g, inputs)
for hidden_l in self.hidden_layers:
h = hidden_l(self.g, h)
h = self.gcn_out(self.g, h)
# (N, num_classes, B) -> (N*num_classes, B); transposed back to batch-first on return.
h = h.reshape([self.g.number_of_nodes()*self.num_classes, -1])
# NOTE(review): leftover timing instrumentation below (together with the
# "start" comment above); it is the only use of the time import.
# torch.cuda.synchronize()
# batch_time = time.time() - start
# print("FORWARD:{:.4f}".format(batch_time*1000))
# start = time.time()
return torch.t(h)
def main():
    """Build the pose graph, instantiate a GCN and run two forward passes.

    Uses CUDA when it is available and falls back to the CPU otherwise.
    """
    G = build_pose_graph()
    print('We have %d nodes.' % G.number_of_nodes())
    print('We have %d edges.' % G.number_of_edges())
    # Positional arguments follow GCN.__init__'s parameter order.
    net = GCN(2, 130, 2, 3, G)
    # Previously hard-coded .cuda(), which crashes on machines without a GPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = net.to(device)
    # Batch of 128 samples with 32 columns each; GCN.forward reshapes the
    # columns into (num_nodes, in_feats).
    output = net(torch.randn(128, 32, device=device))
    # Invoke the module itself rather than .forward() so registered hooks run.
    output = net(torch.randn(128, 32, device=device))
    print('Output shape:', tuple(output.shape))


if __name__ == '__main__':
    main()