This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# vocabulary = OrderedDict() | |
input_length = None | |
vocabulary_size = max(vocabulary.values()) + 1 | |
weights_w2v = list(map(Word2Vec.__getitem__, vocabulary.keys())) | |
embedding_size len(weights_w2v[0]) | |
nb_classes = 5 | |
# CNN hyperparms | |
nb_filter = 64 | |
filter_length = 5 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch import nn | |
class CRF(nn.Module): | |
""" | |
Linear-chain Conditional Random Field (CRF). | |
Args: | |
nb_labels (int): number of labels in your tagset, including special symbols. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def forward(self, emissions, tags, mask=None): | |
"""Compute the negative log-likelihood. See `log_likelihood` method.""" | |
nll = -self.log_likelihood(emissions, tags, mask=mask) | |
return nll | |
def log_likelihood(self, emissions, tags, mask=None): | |
"""Compute the probability of a sequence of tags given a sequence of | |
emissions scores. | |
Args: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _compute_scores(self, emissions, tags, mask): | |
"""Compute the scores for a given batch of emissions with their tags. | |
Args: | |
emissions (torch.Tensor): (batch_size, seq_len, nb_labels) | |
tags (Torch.LongTensor): (batch_size, seq_len) | |
mask (Torch.FloatTensor): (batch_size, seq_len) | |
Returns: | |
torch.Tensor: Scores for each batch. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _compute_log_partition(self, emissions, mask): | |
"""Compute the partition function in log-space using the forward-algorithm. | |
Args: | |
emissions (torch.Tensor): (batch_size, seq_len, nb_labels) | |
mask (Torch.FloatTensor): (batch_size, seq_len) | |
Returns: | |
torch.Tensor: the partition scores for each batch. | |
Shape of (batch_size,) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def decode(self, emissions, mask=None): | |
"""Find the most probable sequence of labels given the emissions using | |
the Viterbi algorithm. | |
Args: | |
emissions (torch.Tensor): Sequence of emissions for each label. | |
Shape (batch_size, seq_len, nb_labels) if batch_first is True, | |
(seq_len, batch_size, nb_labels) otherwise. | |
mask (torch.FloatTensor, optional): Tensor representing valid positions. | |
If None, all positions are considered valid. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# find the best sequence of labels for each sample in the batch | |
best_sequences = [] | |
emission_lengths = mask.int().sum(dim=1) | |
for i in range(batch_size): | |
# recover the original sentence length for the i-th sample in the batch | |
sample_length = emission_lengths[i].item() | |
# recover the max tag for the last timestep |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@staticmethod | |
def select_word_pieces(features, bounds, method='first'): | |
""" | |
Args: | |
features (torch.Tensor): output of BERT. Shape of (bs, ts, h_dim) | |
bounds (torch.LongTensor): the indexes where the word pieces start. | |
Shape of (bs, ts) | |
e.g. Welcome to the jungle -> Wel_ _come _to _the _jungle | |
bounds[0] = [0, 2, 3, 4] | |
indexes for padding positions are expected to be equal to -1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
def create_bow(words, vocab_size, pad_id=None): | |
""" | |
Create a bag of words matrix using torch.sparse.FloatTensor. | |
Args: | |
words (torch.LongTensor): tensor containing ids for words in | |
your vocabulary. Shape of (batch_size, seq_len) | |
vocab_size (int): size of the words vocabulary (including special |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
OlderNewer