@epwalsh
Created September 10, 2019 17:25
from typing import Dict

import torch

from allennlp.models.model import Model
from allennlp.nn import util


class CopyNetSeq2Seq(Model):
    # snip...

    def _decoder_step(self,
                      last_predictions: torch.Tensor,
                      selective_weights: torch.Tensor,
                      state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Run the decoder for one step, combining the embedding of the last prediction
        with an attentive read of the encoder outputs and a selective read of the
        copied source positions, then advancing the decoder LSTM cell.
        """
        # shape: (group_size, max_input_sequence_length)
        encoder_outputs_mask = state["source_mask"].float()
        # shape: (group_size, target_embedding_dim)
        embedded_input = self._target_embedder(last_predictions)
        # shape: (group_size, max_input_sequence_length)
        attentive_weights = self._attention(
                state["decoder_hidden"], state["encoder_outputs"], encoder_outputs_mask)
        # shape: (group_size, encoder_output_dim)
        attentive_read = util.weighted_sum(state["encoder_outputs"], attentive_weights)
        # Selective read over the source tokens only (START and END positions excluded).
        # shape: (group_size, encoder_output_dim)
        selective_read = util.weighted_sum(state["encoder_outputs"][:, 1:-1], selective_weights)
        # shape: (group_size, target_embedding_dim + encoder_output_dim * 2)
        decoder_input = torch.cat((embedded_input, attentive_read, selective_read), -1)
        # shape: (group_size, decoder_input_dim)
        projected_decoder_input = self._input_projection_layer(decoder_input)
        state["decoder_hidden"], state["decoder_context"] = self._decoder_cell(
                projected_decoder_input,
                (state["decoder_hidden"], state["decoder_context"]))
        return state
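
# To make the tensor bookkeeping concrete, here is a minimal, self-contained sketch of the
# same pattern (attentive read + selective read, concatenate, project, LSTM cell step) using
# plain PyTorch instead of the model's own submodules. All names and sizes below
# (group_size, src_len, encoder_dim, target_dim, decoder_dim) are made up for illustration
# and are not part of the gist.
import torch
from torch.nn import LSTMCell, Linear

group_size, src_len = 4, 7
encoder_dim, target_dim, decoder_dim = 10, 6, 10

encoder_outputs = torch.randn(group_size, src_len, encoder_dim)
attentive_weights = torch.softmax(torch.randn(group_size, src_len), dim=-1)
selective_weights = torch.softmax(torch.randn(group_size, src_len - 2), dim=-1)
embedded_input = torch.randn(group_size, target_dim)

# A weighted sum over the time dimension, equivalent in spirit to util.weighted_sum.
attentive_read = torch.bmm(attentive_weights.unsqueeze(1), encoder_outputs).squeeze(1)
selective_read = torch.bmm(selective_weights.unsqueeze(1), encoder_outputs[:, 1:-1]).squeeze(1)

# Concatenate the three reads, project down to the decoder input size, and step the cell.
decoder_input = torch.cat((embedded_input, attentive_read, selective_read), -1)
input_projection = Linear(target_dim + encoder_dim * 2, decoder_dim)
decoder_cell = LSTMCell(decoder_dim, decoder_dim)

hidden = torch.zeros(group_size, decoder_dim)
context = torch.zeros(group_size, decoder_dim)
hidden, context = decoder_cell(input_projection(decoder_input), (hidden, context))
print(hidden.shape)  # torch.Size([4, 10])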