@Mason-McGough
Last active October 2, 2023 09:55
Pointer network attention architecture in PyTorch
import torch
import torch.nn as nn


class PointerNetwork(nn.Module):
    """
    From "Pointer Networks" by Vinyals et al. (2017)
    Adapted from pointer-networks-pytorch by ast0414:
      https://github.com/ast0414/pointer-networks-pytorch

    Args:
      n_hidden: The number of features to expect in the inputs.
    """
    def __init__(
        self,
        n_hidden: int
    ):
        super().__init__()
        self.n_hidden = n_hidden
        # Additive attention parameters: score = v^T tanh(W1 e + W2 d)
        self.w1 = nn.Linear(n_hidden, n_hidden, bias=False)
        self.w2 = nn.Linear(n_hidden, n_hidden, bias=False)
        self.v = nn.Linear(n_hidden, 1, bias=False)

    def forward(
        self,
        x_decoder: torch.Tensor,
        x_encoder: torch.Tensor,
        mask: torch.Tensor,
        eps: float = 1e-16
    ) -> torch.Tensor:
        """
        Args:
          x_decoder: Encoding over the output tokens.
          x_encoder: Encoding over the input tokens.
          mask: Binary mask over the softmax input.

        Shape:
          (B = batch size, Nd = decoder tokens, Ne = encoder tokens, C = n_hidden)
          x_decoder: (B, Nd, C)
          x_encoder: (B, Ne, C)
          mask: (B, Nd, Ne)
          returns: (B, Nd, Ne) log-probabilities over the encoder tokens
        """
        # (B, Nd, Ne, C) <- (B, Ne, C)
        encoder_transform = self.w1(x_encoder).unsqueeze(1).expand(
            -1, x_decoder.shape[1], -1, -1)

        # (B, Nd, 1, C) <- (B, Nd, C)
        decoder_transform = self.w2(x_decoder).unsqueeze(2)

        # (B, Nd, Ne) <- (B, Nd, Ne, C), (B, Nd, 1, C)
        prod = self.v(torch.tanh(encoder_transform + decoder_transform)).squeeze(-1)

        # (B, Nd, Ne) <- (B, Nd, Ne)
        # NOTE: masked_log_softmax is not defined in this snippet and must be
        # supplied separately.
        log_score = masked_log_softmax(prod, mask, dim=-1, eps=eps)

        return log_score
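
The forward pass calls masked_log_softmax, which this gist does not define. Below is a minimal sketch of one way such a helper could look, assuming the mask uses 1 for positions to keep and 0 for positions to exclude, and that eps is added to the mask before taking its log. The signature is inferred from the call site only; this is an illustration, not necessarily the helper the author used.

```python
import torch


def masked_log_softmax(
    x: torch.Tensor,
    mask: torch.Tensor,
    dim: int = -1,
    eps: float = 1e-16
) -> torch.Tensor:
    """Log-softmax over `dim` that suppresses positions where mask == 0.

    Assumption: mask is binary (1 = keep, 0 = exclude) and broadcastable to x.
    """
    # log(1 + eps) is ~0, so kept scores are unchanged; log(0 + eps) is a large
    # negative number, so excluded scores get near-zero probability.
    x = x + (mask.to(x.dtype) + eps).log()
    return torch.log_softmax(x, dim=dim)


# Example: B=1, Nd=2, Ne=3, with the last encoder token hidden from the
# first decoder step.
scores = torch.zeros(1, 2, 3)
mask = torch.tensor([[[1, 1, 0], [1, 1, 1]]])
log_probs = masked_log_softmax(scores, mask)
```

Adding log(mask + eps) rather than filling with -inf keeps the output finite even when a row is masked out entirely, at the cost of leaving a tiny probability on excluded positions.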
@Ririkosann

Hello.
I am writing here to ask a question about the program.
I have read your article on Pointer Networks with Transformers.
It is very interesting as it combines pointer network and Transformer.
I ran the Google Colab notebook you linked to, but training failed with "ValueError: Inconsistent coordinate dimensionality".
How can I run it?

@Mason-McGough
Author

Hi there, thank you for your interest in my article. Do you think it could be related to this question? The error message sounds familiar.

In either case, I believe it should be fixed now. Please give it a try and let me know if you still encounter issues.

@Ririkosann

As you said, it was fixed.
Thank you very much!

@Mason-McGough
Author

Happy to help!
