@epwalsh
Created September 10, 2019 18:38
from typing import Tuple

import torch

from allennlp.models.model import Model
from allennlp.nn import util


class CopyNetSeq2Seq(Model):
    # snip...

    def _get_ll_contrib(self,
                        generation_scores: torch.Tensor,
                        generation_scores_mask: torch.Tensor,
                        copy_scores: torch.Tensor,
                        target_tokens: torch.Tensor,
                        target_to_source: torch.Tensor,
                        copy_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Get the log-likelihood contribution from a single timestep.
Parameters
----------
generation_scores : ``torch.Tensor``
Shape: `(batch_size, target_vocab_size)`
generation_scores_mask : ``torch.Tensor``
Shape: `(batch_size, target_vocab_size)`. This is just a tensor of 1's.
copy_scores : ``torch.Tensor``
Shape: `(batch_size, trimmed_source_length)`
target_tokens : ``torch.Tensor``
Shape: `(batch_size,)`
target_to_source : ``torch.Tensor``
Shape: `(batch_size, trimmed_source_length)`
copy_mask : ``torch.Tensor``
Shape: `(batch_size, trimmed_source_length)`
Returns
-------
Tuple[torch.Tensor, torch.Tensor]
Shape: `(batch_size,), (batch_size, max_input_sequence_length)`
"""
        _, target_size = generation_scores.size()

        # This mask zeroes out the scores of source tokens that are just padding.
        # We apply it to the concatenation of the generation scores and the copy
        # scores so that the softmax normalizes the combined scores correctly.
        # shape: (batch_size, target_vocab_size + trimmed_source_length)
        mask = torch.cat((generation_scores_mask, copy_mask), dim=-1)
        # shape: (batch_size, target_vocab_size + trimmed_source_length)
        all_scores = torch.cat((generation_scores, copy_scores), dim=-1)
        # Normalize generation and copy scores together.
        # shape: (batch_size, target_vocab_size + trimmed_source_length)
        log_probs = util.masked_log_softmax(all_scores, mask)
        # Calculate the log probability (`copy_log_probs`) for each token in the source sentence
        # that matches the current target token. We use the sum of these copy probabilities
        # for matching tokens in the source sentence to get the total probability
        # for the target token. We also need to normalize the individual copy probabilities
        # to create `selective_weights`, which are used in the next timestep to create
        # a selective read state.
        # shape: (batch_size, trimmed_source_length)
        copy_log_probs = log_probs[:, target_size:] + (target_to_source.float() + 1e-45).log()
        # Since `log_probs[:, target_size:]` gives us the raw copy log probabilities,
        # we use a non-log softmax to get the normalized non-log copy probabilities.
        selective_weights = util.masked_softmax(log_probs[:, target_size:], target_to_source)
        # This mask ensures that an item in the batch has a non-zero generation probability
        # at this timestep only when the gold target token is not OOV or when there are no
        # matching tokens in the source sentence.
        # shape: (batch_size,)
        gen_mask = ((target_tokens != self._oov_index) | (target_to_source.sum(-1) == 0)).float()
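        # For example, if the gold target token is OOV for the target vocabulary but does
        # appear in the source sentence, `gen_mask` is 0 and the generation branch is
        # effectively switched off below (`log_gen_mask` ~= -inf), so all of the probability
        # mass for this step comes from the copy scores.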
        # shape: (batch_size, 1)
        log_gen_mask = (gen_mask + 1e-45).log().unsqueeze(-1)
        # Now we get the generation score for the gold target token.
        # shape: (batch_size, 1)
        generation_log_probs = log_probs.gather(1, target_tokens.unsqueeze(1)) + log_gen_mask
        # ... and add the copy score to get the step log likelihood.
        # shape: (batch_size, 1 + trimmed_source_length)
        combined_gen_and_copy = torch.cat((generation_log_probs, copy_log_probs), dim=-1)
        # shape: (batch_size,)
        step_log_likelihood = util.logsumexp(combined_gen_and_copy)
        return step_log_likelihood, selective_weights
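
For intuition, here is a small self-contained sketch of the same computation with invented numbers. It is not part of the model: the batch size, vocabulary size, source length, and scores are all made up, `allennlp.nn.util` is assumed to be importable, and the OOV masking of the generation branch (`log_gen_mask`) is omitted since the toy gold token is in the target vocabulary.

# Toy illustration of the joint normalization and logsumexp merge above (numbers invented).
import torch
from allennlp.nn import util

batch_size, target_vocab_size, trimmed_source_length = 1, 5, 3

# Unnormalized scores a decoder might produce at one timestep.
generation_scores = torch.randn(batch_size, target_vocab_size)
copy_scores = torch.randn(batch_size, trimmed_source_length)
generation_scores_mask = torch.ones(batch_size, target_vocab_size)
copy_mask = torch.ones(batch_size, trimmed_source_length)  # no source padding here

# Suppose the gold target token has vocab index 2 and also matches the first source token.
target_tokens = torch.tensor([2])
target_to_source = torch.tensor([[1.0, 0.0, 0.0]])

# Jointly normalize generation and copy scores, as in `_get_ll_contrib`.
mask = torch.cat((generation_scores_mask, copy_mask), dim=-1)
all_scores = torch.cat((generation_scores, copy_scores), dim=-1)
log_probs = util.masked_log_softmax(all_scores, mask)

# Log probability of generating the gold token from the target vocabulary ...
generation_log_probs = log_probs.gather(1, target_tokens.unsqueeze(1))
# ... and of copying it from each matching source position.
copy_log_probs = log_probs[:, target_vocab_size:] + (target_to_source + 1e-45).log()

# The step log likelihood is the total probability over these mutually
# exclusive ways of producing the gold token.
step_log_likelihood = util.logsumexp(torch.cat((generation_log_probs, copy_log_probs), dim=-1))
print(step_log_likelihood)  # shape: (batch_size,)

Roughly speaking, the full model sums these per-step log likelihoods over the non-padded target positions to form the training loss, while `selective_weights` feeds the selective read state at the next timestep.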