@liuyijiang1994
Created November 18, 2019 04:36
GCN
import torch
import torch.nn as nn
import torch.nn.functional as F


class GCN(nn.Module):
    """A GCN/Contextualized GCN module that operates on dependency graphs."""

    def __init__(self, in_dim, mem_dim, num_layers, in_drop=0.5, out_drop=0.5, batch=True):
        super(GCN, self).__init__()
        self.layers = num_layers
        self.mem_dim = mem_dim
        self.in_dim = in_dim
        self.in_drop = nn.Dropout(in_drop)
        self.gcn_drop = nn.Dropout(out_drop)
        self.batch = batch
        # GCN layers: the first maps in_dim -> mem_dim, the rest keep mem_dim.
        self.W = nn.ModuleList()
        for layer in range(self.layers):
            input_dim = self.in_dim if layer == 0 else self.mem_dim
            self.W.append(nn.Linear(input_dim, self.mem_dim))

    def conv_l2(self):
        """Squared L2 norm of the GCN layer parameters, for regularization."""
        conv_weights = []
        for w in self.W:
            conv_weights += [w.weight, w.bias]
        return sum([x.pow(2).sum() for x in conv_weights])
    def forward(self, adj, token_encode):
        """
        :param adj: (batch, seqlen, seqlen) adjacency matrix of the dependency graph
        :param token_encode: (batch, seqlen, dim) token representations
        :return: (batch, seqlen, mem_dim) node representations and a padding mask
        """
        if not self.batch:
            adj = adj.unsqueeze(0)
            token_encode = token_encode.unsqueeze(0)
        gcn_inputs = self.in_drop(token_encode)

        # Degree of each node plus one for the self loop, used to normalize.
        denom = adj.sum(2).unsqueeze(2) + 1
        # Tokens with no incoming or outgoing edges are treated as padding.
        mask = (adj.sum(2) + adj.sum(1)).eq(0).unsqueeze(2)

        for l in range(self.layers):
            Ax = adj.bmm(gcn_inputs)           # aggregate neighbor representations
            AxW = self.W[l](Ax)
            AxW = AxW + self.W[l](gcn_inputs)  # self loop
            AxW = AxW / denom
            gAxW = F.relu(AxW)
            gcn_inputs = self.gcn_drop(gAxW) if l < self.layers - 1 else gAxW
        return gcn_inputs, mask
def pool(h, mask, pool_type='max'):
    """Pool node representations over the sequence dimension, ignoring masked tokens."""
    if pool_type == 'max':
        h = h.masked_fill(mask, -1e12)
        return torch.max(h, 1)[0]
    elif pool_type == 'avg':
        h = h.masked_fill(mask, 0)
        return h.sum(1) / (mask.size(1) - mask.float().sum(1))
    else:  # sum
        h = h.masked_fill(mask, 0)
        return h.sum(1)
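
For reference, a minimal usage sketch with the imports from the snippet above. The tensor sizes and the random binary adjacency are illustrative assumptions, not part of the gist:

batch, seqlen, in_dim, mem_dim = 2, 5, 16, 32
gcn = GCN(in_dim=in_dim, mem_dim=mem_dim, num_layers=2)

# Random stand-ins for a dependency graph and encoded tokens.
adj = torch.randint(0, 2, (batch, seqlen, seqlen)).float()
token_encode = torch.randn(batch, seqlen, in_dim)

h, mask = gcn(adj, token_encode)        # h: (batch, seqlen, mem_dim)
sent = pool(h, mask, pool_type='max')   # sentence-level vector: (batch, mem_dim)
print(h.shape, sent.shape)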