Skip to content

Instantly share code, notes, and snippets.

@liuyijiang1994
Created November 18, 2019 04:37
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save liuyijiang1994/81f7c53647bbaeefff60fbe97d622405 to your computer and use it in GitHub Desktop.
Save liuyijiang1994/81f7c53647bbaeefff60fbe97d622405 to your computer and use it in GitHub Desktop.
BiAffine
class Biaffine(nn.Module):
    """Biaffine scoring layer over every pair of positions in two sequences.

    For each batch element, each position ``i`` of ``input1`` and each
    position ``j`` of ``input2``, the layer produces ``out_features`` scores
    ``x1_i^T W_o x2_j``. When a bias flag is set, a constant ``1`` feature is
    appended to the corresponding input before the product is taken.

    All ``out_features`` bilinear maps are stored in one bias-free
    ``nn.Linear`` whose weight has shape
    ``(out_features * (in2_features + bias[1]), in1_features + bias[0])``.

    Args:
        in1_features: Size of the last dimension of ``input1``.
        in2_features: Size of the last dimension of ``input2``.
        out_features: Number of scores produced per position pair.
        bias: Pair of flags; ``bias[0]`` appends a ones column to ``input1``
            and ``bias[1]`` appends a ones column to ``input2``.
    """

    def __init__(self, in1_features: int, in2_features: int, out_features: int,
                 bias=(True, True)):
        super(Biaffine, self).__init__()
        self.in1_features = in1_features
        self.in2_features = in2_features
        self.out_features = out_features
        self.bias = bias
        # Each enabled bias flag widens the matching input by one column.
        self.linear_input_size = in1_features + int(bias[0])
        self.linear_output_size = out_features * (in2_features + int(bias[1]))
        self.linear = nn.Linear(in_features=self.linear_input_size,
                                out_features=self.linear_output_size,
                                bias=False)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Set every entry of the bilinear weight to zero."""
        # In-place init on the parameter itself: no NumPy round-trip and no
        # access through the deprecated ``.data`` attribute.
        nn.init.zeros_(self.linear.weight)

    def forward(self, input1: torch.Tensor, input2: torch.Tensor) -> torch.Tensor:
        """Score all position pairs of ``input1`` and ``input2``.

        Args:
            input1: Tensor of shape ``(batch, len1, in1_features)``.
            input2: Tensor of shape ``(batch, len2, in2_features)``.

        Returns:
            Contiguous tensor of shape ``(batch, len1, len2, out_features)``.
        """
        batch_size, len1, _ = input1.size()
        _, len2, _ = input2.size()

        if self.bias[0]:
            # new_ones keeps the dtype and device of the input.
            input1 = torch.cat((input1, input1.new_ones(batch_size, len1, 1)), dim=2)
        if self.bias[1]:
            input2 = torch.cat((input2, input2.new_ones(batch_size, len2, 1)), dim=2)
        dim2 = input2.size(2)

        # (batch, len1, dim1) -> (batch, len1, out_features * dim2)
        affine = self.linear(input1)
        # The linear output's last axis is laid out as (out_features, dim2);
        # fold out_features into the row axis so one bmm covers every output.
        affine = affine.view(batch_size, len1 * self.out_features, dim2)

        # (batch, len1 * out_features, dim2) @ (batch, dim2, len2)
        scores = torch.bmm(affine, input2.transpose(1, 2))

        # (batch, len1 * out, len2) -> (batch, len1, out, len2)
        #                           -> (batch, len1, len2, out)
        scores = scores.view(batch_size, len1, self.out_features, len2)
        return scores.permute(0, 1, 3, 2).contiguous()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment