@JohnGiorgi
Last active November 2, 2023 04:58
PyTorch implementation of the biaffine attention operator from "End-to-end neural relation extraction using deep biaffine attention" (https://arxiv.org/abs/1812.11275) which can be used as a classifier for binary relation classification. If you spot an error or have an improvement, let me know!
import torch


class BiaffineAttention(torch.nn.Module):
    """Implements a biaffine attention operator for binary relation classification.

    PyTorch implementation of the biaffine attention operator from "End-to-end neural relation
    extraction using deep biaffine attention" (https://arxiv.org/abs/1812.11275) which can be used
    as a classifier for binary relation classification.

    Args:
        in_features (int): The size of the feature dimension of the inputs.
        out_features (int): The size of the feature dimension of the output.

    Shape:
        - x_1: `(N, *, in_features)`, where `N` is the batch dimension and `*` means any number of
          additional dimensions.
        - x_2: `(N, *, in_features)`, where `N` is the batch dimension and `*` means any number of
          additional dimensions.
        - Output: `(N, *, out_features)`, where `N` is the batch dimension and `*` means any number
          of additional dimensions.

    Examples:
        >>> batch_size, in_features, out_features = 32, 100, 4
        >>> biaffine_attention = BiaffineAttention(in_features, out_features)
        >>> x_1 = torch.randn(batch_size, in_features)
        >>> x_2 = torch.randn(batch_size, in_features)
        >>> output = biaffine_attention(x_1, x_2)
        >>> print(output.size())
        torch.Size([32, 4])
    """

    def __init__(self, in_features, out_features):
        super(BiaffineAttention, self).__init__()

        self.in_features = in_features
        self.out_features = out_features

        # Bilinear term x_1^T U x_2 (no bias; the linear layer below carries it).
        self.bilinear = torch.nn.Bilinear(in_features, in_features, out_features, bias=False)
        # Linear term W [x_1; x_2] + b over the concatenated inputs.
        self.linear = torch.nn.Linear(2 * in_features, out_features, bias=True)

        self.reset_parameters()

    def forward(self, x_1, x_2):
        return self.bilinear(x_1, x_2) + self.linear(torch.cat((x_1, x_2), dim=-1))

    def reset_parameters(self):
        self.bilinear.reset_parameters()
        self.linear.reset_parameters()
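
To make the arithmetic in `forward` concrete, here is a minimal sketch that writes out the two terms by hand and checks them against the layers. It is illustrative only and not part of the gist: the small sizes, the seed, and the variable names (`bilinear_term`, `linear_term`, `manual_output`, `num_pairs`) are assumptions chosen for the example. It rebuilds the same two layers the module creates in `__init__` so that it runs on its own.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only for reproducibility of the sketch
batch_size, in_features, out_features = 8, 5, 3  # illustrative sizes

# The same two layers that BiaffineAttention.__init__ constructs.
bilinear = torch.nn.Bilinear(in_features, in_features, out_features, bias=False)
linear = torch.nn.Linear(2 * in_features, out_features, bias=True)

x_1 = torch.randn(batch_size, in_features)
x_2 = torch.randn(batch_size, in_features)

# What forward() computes.
module_output = bilinear(x_1, x_2) + linear(torch.cat((x_1, x_2), dim=-1))

# The same computation written out term by term.
# bilinear.weight has shape (out_features, in_features, in_features), so each
# output unit o gets the scalar x_1^T U_o x_2.
bilinear_term = torch.einsum("bi,oij,bj->bo", x_1, bilinear.weight, x_2)
# linear.weight has shape (out_features, 2 * in_features).
linear_term = torch.cat((x_1, x_2), dim=-1) @ linear.weight.T + linear.bias
manual_output = bilinear_term + linear_term

matches = torch.allclose(module_output, manual_output, atol=1e-4)

# Inputs with an extra dimension, e.g. several candidate entity pairs per example.
num_pairs = 6  # hypothetical number of pairs per example
y_1 = torch.randn(batch_size, num_pairs, in_features)
y_2 = torch.randn(batch_size, num_pairs, in_features)
pair_output = bilinear(y_1, y_2) + linear(torch.cat((y_1, y_2), dim=-1))
pair_shape = tuple(pair_output.shape)  # (batch_size, num_pairs, out_features)
```

Because the bilinear layer is created with `bias=False`, the single bias vector comes from the linear layer, so the output for each pair is `x_1^T U x_2 + W [x_1; x_2] + b` per output unit.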