@Mason-McGough
Last active October 2, 2023 09:55
Pointer network attention architecture in PyTorch
import torch
import torch.nn as nn


class PointerNetwork(nn.Module):
    """
    From "Pointer Networks" by Vinyals et al. (2015)

    Adapted from pointer-networks-pytorch by ast0414:
    https://github.com/ast0414/pointer-networks-pytorch

    Args:
        n_hidden: The number of features to expect in the inputs.
    """
    def __init__(
        self,
        n_hidden: int
    ):
        super().__init__()
        self.n_hidden = n_hidden
        self.w1 = nn.Linear(n_hidden, n_hidden, bias=False)
        self.w2 = nn.Linear(n_hidden, n_hidden, bias=False)
        self.v = nn.Linear(n_hidden, 1, bias=False)

    def forward(
        self,
        x_decoder: torch.Tensor,
        x_encoder: torch.Tensor,
        mask: torch.Tensor,
        eps: float = 1e-16
    ) -> torch.Tensor:
        """
        Args:
            x_decoder: Encoding over the output (decoder) tokens.
            x_encoder: Encoding over the input (encoder) tokens.
            mask: Binary mask over the softmax input.
            eps: Small constant passed to masked_log_softmax.

        Shape:
            x_decoder: (B, Nd, C)
            x_encoder: (B, Ne, C)
            mask: (B, Nd, Ne)
        """
        # (B, Nd, Ne, C) <- (B, Ne, C)
        encoder_transform = self.w1(x_encoder).unsqueeze(1).expand(
            -1, x_decoder.shape[1], -1, -1)
        # (B, Nd, 1, C) <- (B, Nd, C)
        decoder_transform = self.w2(x_decoder).unsqueeze(2)
        # (B, Nd, Ne) <- (B, Nd, Ne, C), (B, Nd, 1, C)
        prod = self.v(torch.tanh(encoder_transform + decoder_transform)).squeeze(-1)
        # (B, Nd, Ne) <- (B, Nd, Ne)
        log_score = masked_log_softmax(prod, mask, dim=-1, eps=eps)
        return log_score
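
The forward pass calls masked_log_softmax, which is not defined in this snippet (it comes from the ast0414 repository linked above). Below is a minimal sketch consistent with how it is called here, assuming the same signature; the exact implementation in that repository may differ:

def masked_log_softmax(
    vector: torch.Tensor,
    mask: torch.Tensor,
    dim: int = -1,
    eps: float = 1e-16
) -> torch.Tensor:
    # Entries where mask == 0 get log(eps) added to their logits,
    # pushing them toward zero probability while keeping values finite.
    vector = vector + (mask.float() + eps).log()
    return torch.nn.functional.log_softmax(vector, dim=dim)

A quick shape check, using the class above with random inputs (names and sizes are illustrative only):

B, Ne, Nd, C = 2, 5, 3, 16
net = PointerNetwork(n_hidden=C)
x_encoder = torch.randn(B, Ne, C)   # encoder states over Ne input tokens
x_decoder = torch.randn(B, Nd, C)   # decoder states over Nd output steps
mask = torch.ones(B, Nd, Ne)        # attend to every encoder position
log_scores = net(x_decoder, x_encoder, mask)  # (B, Nd, Ne)
pointers = log_scores.argmax(dim=-1)          # chosen input index per output step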
@Ririkosann

As you said, it was fixed.
Thank you very much!

@Mason-McGough (Author)

Happy to help!
