Skip to content

Instantly share code, notes, and snippets.

@mtreviso
Last active November 4, 2019 17:33
Show Gist options
  • Save mtreviso/9d0188073c9b6c2ec5718d7bf95b7da9 to your computer and use it in GitHub Desktop.
Save mtreviso/9d0188073c9b6c2ec5718d7bf95b7da9 to your computer and use it in GitHub Desktop.
Method to pool BERT word-piece vectors into one vector per word (strategies: first, mean, sum, max)
@staticmethod
def select_word_pieces(features, bounds, method='first'):
    """Pool word-piece vectors into one vector per word.

    Word ``j`` of row ``i`` spans the pieces
    ``bounds[i, j] .. bounds[i, j + 1] - 1``. The inputs carry no information
    about where the *last* valid word of a row ends, so every strategy treats
    that word as the single piece ``bounds[i, n - 1]``.

    Args:
        features (torch.Tensor): output of BERT. Shape of (bs, num_pieces, h_dim)
        bounds (torch.LongTensor): the indexes where the word pieces start.
            Shape of (bs, ts)
            e.g. Welcome to the jungle -> Wel_ _come _to _the _jungle
            bounds[0] = [0, 2, 3, 4]
            indexes for padding positions are expected to be equal to -1,
            and must come after all valid indexes of their row
        method (str): the strategy used to get a representation of a word
            based on its word pieces. Possible choices are:
            'first' = take the vector on the position of the first word piece
            'sum' = take the sum of the vectors of the word pieces
            'mean' = take the average of the vectors of the word pieces
            'max' = take the max of the vectors of the word pieces
    Returns:
        torch.Tensor: (bs, ts, h_dim) for 'first', 'sum' and 'mean';
        (bs, largest number of valid bounds in the batch, h_dim) for 'max'.
        Positions corresponding to padding are filled with zeros.
    Raises:
        ValueError: if ``method`` is not one of the choices above.
    """
    bs, _, hidden_dim = features.size()
    valid = bounds != -1  # (bs, ts); False at padding positions

    if method in ('first', 'sum', 'mean'):
        r = torch.arange(bs, device=features.device).unsqueeze(1)
        # Redirect padding (-1) to piece 0 so every gather stays in range;
        # those positions are zeroed out before returning.
        starts = bounds.clamp(min=0)
        if method == 'first':
            pooled = features[r, starts]
        else:
            # A word ends where the next one starts. The last valid word
            # (followed by padding or by the end of ``bounds``) ends one piece
            # after its start.
            next_starts = torch.cat(
                (bounds[:, 1:], bounds.new_full((bs, 1), -1)), dim=1
            )
            ends = torch.where(next_starts == -1, starts + 1, next_starts)
            # After prepending a zero row, cumsum[:, k] == features[:, :k].sum(1),
            # so the sum over [start, end) is cumsum[end] - cumsum[start].
            # Using each word's own start (not the previous word's end) keeps
            # pieces before bounds[:, 0] (e.g. a [CLS] vector) out of word 0.
            cumsum = torch.cat(
                (features.new_zeros(bs, 1, hidden_dim), features.cumsum(dim=1)),
                dim=1,
            )
            pooled = cumsum[r, ends] - cumsum[r, starts]
            if method == 'mean':
                # Valid words have ends > starts; the clamp only matters for
                # padding, which is masked below anyway.
                lens = (ends - starts).clamp(min=1)
                pooled = pooled / lens.unsqueeze(-1).to(pooled.dtype)
        return pooled.masked_fill(~valid.unsqueeze(-1), 0)

    if method == 'max':
        lengths = valid.sum(dim=1).tolist()
        max_words = max(lengths, default=0)
        # Allocate like ``features`` so dtype and device match the input.
        pooled = features.new_zeros(bs, max_words, hidden_dim)
        for i, n_words in enumerate(lengths):
            if n_words == 0:
                continue  # all-padding row: leave it as zeros
            word_starts = bounds[i, :n_words].tolist()
            word_ends = word_starts[1:] + [word_starts[-1] + 1]
            for j, (k1, k2) in enumerate(zip(word_starts, word_ends)):
                x, _ = torch.max(features[i, k1:k2], dim=0)
                pooled[i, j] = x
        return pooled

    raise ValueError('Method {} is not implemented'.format(method))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment