Skip to content

Instantly share code, notes, and snippets.

@mtreviso
Last active January 3, 2021 06:10
Show Gist options
  • Save mtreviso/66aa2b61246e89047618b4aede71a002 to your computer and use it in GitHub Desktop.
Save mtreviso/66aa2b61246e89047618b4aede71a002 to your computer and use it in GitHub Desktop.
Code displayed on Medium.
def forward(self, emissions, tags, mask=None):
    """Return the negative log-likelihood of ``tags`` given ``emissions``.

    This simply negates ``self.log_likelihood(emissions, tags, mask=mask)``,
    which makes it usable directly as a loss to minimize. See
    ``log_likelihood`` for the expected argument shapes.
    """
    return -self.log_likelihood(emissions, tags, mask=mask)
def log_likelihood(self, emissions, tags, mask=None):
    """Compute the log-likelihood of a sequence of tags given emission scores.

    For each sequence this computes ``scores - partition``. Here ``scores``
    comes from ``self._compute_scores`` (presumably the score of the given tag
    path) and ``partition`` comes from ``self._compute_log_partition``. The
    per-sequence values are then summed over the batch.

    Args:
        emissions (torch.Tensor): Sequence of emissions for each label.
            Shape of (batch_size, seq_len, nb_labels) if batch_first is True,
            (seq_len, batch_size, nb_labels) otherwise.
        tags (torch.LongTensor): Sequence of labels.
            Shape of (batch_size, seq_len) if batch_first is True,
            (seq_len, batch_size) otherwise.
        mask (torch.FloatTensor, optional): Tensor representing valid positions.
            If None, all positions are considered valid.
            Shape of (batch_size, seq_len) if batch_first is True,
            (seq_len, batch_size) otherwise.

    Returns:
        torch.Tensor: the summed log-likelihoods of all sequences in the
            batch, as a 0-dim (scalar) tensor.
    """
    # Work internally with batch as the first dimension. The mask must be
    # transposed together with emissions/tags. Otherwise a user-supplied
    # (seq_len, batch_size) mask would be misaligned with them.
    if not self.batch_first:
        emissions = emissions.transpose(0, 1)
        tags = tags.transpose(0, 1)
        if mask is not None:
            mask = mask.transpose(0, 1)

    if mask is None:
        # Build the default mask on the same device as the emissions so that
        # GPU inputs are not combined with a CPU-resident mask.
        mask = torch.ones(
            emissions.shape[:2], dtype=torch.float, device=emissions.device
        )

    scores = self._compute_scores(emissions, tags, mask=mask)
    partition = self._compute_log_partition(emissions, mask=mask)
    return torch.sum(scores - partition)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment