Last active
January 3, 2021 06:10
-
-
Save mtreviso/66aa2b61246e89047618b4aede71a002 to your computer and use it in GitHub Desktop.
Code displayed on medium.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def forward(self, emissions, tags, mask=None):
    """Return the negative log-likelihood loss.

    Thin wrapper that negates `log_likelihood`; see that method for the
    full description of `emissions`, `tags` and `mask`.
    """
    return -self.log_likelihood(emissions, tags, mask=mask)
def log_likelihood(self, emissions, tags, mask=None):
    """Compute the probability of a sequence of tags given a sequence of
    emissions scores.

    Args:
        emissions (torch.Tensor): Sequence of emissions for each label.
            Shape of (batch_size, seq_len, nb_labels) if batch_first is True,
            (seq_len, batch_size, nb_labels) otherwise.
        tags (torch.LongTensor): Sequence of labels.
            Shape of (batch_size, seq_len) if batch_first is True,
            (seq_len, batch_size) otherwise.
        mask (torch.FloatTensor, optional): Tensor representing valid positions.
            If None, all positions are considered valid.
            Shape of (batch_size, seq_len) if batch_first is True,
            (seq_len, batch_size) otherwise.

    Returns:
        torch.Tensor: the (summed) log-likelihoods of each sequence in the batch.
            Shape of (1,)
    """
    # Fix tensors order by setting batch as the first dimension.
    if not self.batch_first:
        emissions = emissions.transpose(0, 1)
        tags = tags.transpose(0, 1)

    # Default mask marks every position valid. Allocate it on the same
    # device as `emissions`: the original allocated on CPU, which fails
    # for CUDA inputs when the mask is combined with emissions downstream.
    if mask is None:
        mask = torch.ones(
            emissions.shape[:2], dtype=torch.float, device=emissions.device
        )

    # Numerator: score of the observed tag sequence; denominator: log of
    # the partition function over all tag sequences (forward algorithm).
    scores = self._compute_scores(emissions, tags, mask=mask)
    partition = self._compute_log_partition(emissions, mask=mask)
    return torch.sum(scores - partition)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment