@epwalsh
Created September 10, 2019 17:23
from typing import Dict

import torch

from allennlp.models.model import Model


class CopyNetSeq2Seq(Model):
    # snip...

    def _get_copy_scores(self, state: Dict[str, torch.Tensor]) -> torch.Tensor:
        # NOTE: here `trimmed_source_length` refers to the input sequence length minus 2,
        # so that the special START and END tokens in the source are ignored. We also need
        # to ignore PAD tokens, but that happens elsewhere using a mask.

        # shape: (batch_size, trimmed_source_length, encoder_output_dim)
        trimmed_encoder_outputs = state["encoder_outputs"][:, 1:-1]

        # shape: (batch_size, trimmed_source_length, decoder_output_dim)
        copy_projection = self._output_copying_layer(trimmed_encoder_outputs)

        # shape: (batch_size, trimmed_source_length, decoder_output_dim)
        copy_projection = torch.tanh(copy_projection)

        # shape: (batch_size, trimmed_source_length)
        copy_scores = copy_projection.bmm(state["decoder_hidden"].unsqueeze(-1)).squeeze(-1)

        return copy_scores
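For intuition, the same computation can be run standalone with dummy tensors: trim the
START/END positions, project the encoder outputs into the decoder's space, squash with
tanh, and take a batched dot product against the decoder hidden state to get one copy
score per source token. This is a minimal sketch; the sizes and the `Linear` layer
standing in for `self._output_copying_layer` are illustrative assumptions, not the
model's actual configuration.

import torch

# Hypothetical sizes, chosen only for illustration.
batch_size, source_length = 2, 7
encoder_output_dim, decoder_output_dim = 16, 12

# Dummy encoder outputs (including START and END positions) and decoder hidden state.
encoder_outputs = torch.randn(batch_size, source_length, encoder_output_dim)
decoder_hidden = torch.randn(batch_size, decoder_output_dim)

# Stand-in for `self._output_copying_layer`.
output_copying_layer = torch.nn.Linear(encoder_output_dim, decoder_output_dim)

# shape: (batch_size, source_length - 2, encoder_output_dim)
trimmed = encoder_outputs[:, 1:-1]
# shape: (batch_size, source_length - 2, decoder_output_dim)
projection = torch.tanh(output_copying_layer(trimmed))
# shape: (batch_size, source_length - 2) -- one score per (non-special) source token
scores = projection.bmm(decoder_hidden.unsqueeze(-1)).squeeze(-1)

assert scores.shape == (batch_size, source_length - 2)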