Created September 10, 2019 17:23
Save epwalsh/134d7f42268f1bcee466d12f5641defe to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class CopyNetSeq2Seq(Model):
    # snip... (other members of the model are omitted from this excerpt)

    def _get_copy_scores(self, state: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Score each copyable source position against the current decoder state.

        For every source position except the first and the last, the encoder
        output is projected by ``self._output_copying_layer``, squashed with
        ``tanh``, and dotted with the decoder hidden state. The result is the
        raw (unnormalized) copy score for that position.

        Args:
            state: Decoder state dictionary. Only two entries are read:
                ``"encoder_outputs"`` with shape
                (batch_size, source_length, encoder_output_dim), and
                ``"decoder_hidden"``, which must be 2-D,
                (batch_size, decoder_output_dim), for the batched matmul
                below to line up.

        Returns:
            Tensor of shape (batch_size, source_length - 2) holding one score
            per trimmed source position.
        """
        # The first and last source positions hold the special START and END
        # tokens, so they are sliced away and cannot be copied. PAD positions
        # are NOT removed here; a mask applied elsewhere handles them.
        # NOTE(review): assuming right-padding, position -1 is PAD for the
        # shorter sequences in a batch, so their END token stays inside this
        # span. Presumably the external mask also covers that case; confirm.
        # shape: (batch_size, trimmed_source_length, encoder_output_dim)
        candidate_outputs = state["encoder_outputs"][:, 1:-1]

        # shape: (batch_size, trimmed_source_length, decoder_output_dim)
        projected = torch.tanh(self._output_copying_layer(candidate_outputs))

        # Turn the decoder hidden state into a column vector per batch item so a
        # single batched matmul produces all dot products at once.
        # shape: (batch_size, decoder_output_dim, 1)
        query = state["decoder_hidden"].unsqueeze(-1)

        # (B, T, D) @ (B, D, 1) -> (B, T, 1); drop the trailing singleton dim.
        # shape: (batch_size, trimmed_source_length)
        return torch.bmm(projected, query).squeeze(-1)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.