@astariul
Created March 11, 2019 06:24
import numpy as np
import torch

# Note: this is a method of a seq2seq model wrapper (it uses self.model).
# START_TOKEN_ID, STOP_TOKEN_ID and DEVICE are assumed to be defined elsewhere
# in the module (special vocabulary ids and the torch device).


def _beam_search(self, batch, beam_width, max_len):
    """Beam search for predicting a sentence."""
    batch_size = batch['input_t'].size(1)
    with torch.no_grad():
        encoder_hidden, encoder_final = self.model.encode(
            batch['input_t'].transpose(0, 1),
            batch['input_mask'].transpose(0, 1),
            batch['input_len'])

    # The first decoder input is the start token, repeated for the whole batch.
    prev_y = torch.ones(batch_size, 1).fill_(START_TOKEN_ID).type_as(
        batch['input_t'].transpose(0, 1))
    trg_mask = torch.ones_like(prev_y)

    candidate = {
        'prev_y': prev_y,
        'output': [prev_y],
        'attention': [],
        'hidden': None,
        'score': torch.tensor([1.0] * batch_size, device=DEVICE)
    }
    candidates = [candidate]  # Start beam search with only 1 candidate
    for _ in range(max_len):
        next_candidates = []
        for candidate in candidates:
            with torch.no_grad():
                out, hidden, pre_output = self.model.decode(
                    encoder_hidden,
                    encoder_final,
                    batch['input_mask'].transpose(0, 1),
                    candidate['prev_y'],
                    trg_mask,
                    candidate['hidden'])

            # We predict from the pre-output layer, which is a combination
            # of the decoder state, the previous embedding and the context.
            prob = self.model.generator(pre_output[:, -1])
            topb, next_word = prob.topk(beam_width)  # [batch_size, beam_width]

            candidate['attention'].append(
                self.model.decoder.attention.alphas.cpu().numpy())

            for b in range(beam_width):
                next_word_b = next_word[:, b].unsqueeze(-1)
                topb_b = topb[:, b]
                next_candidates.append({
                    'prev_y': next_word_b,
                    'output': candidate['output'] + [next_word_b],
                    'attention': candidate['attention'],
                    'hidden': hidden,
                    'score': candidate['score'] * topb_b
                })

        # For each current candidate we add beam_width new candidates, so we
        # have beam_width candidates after the first iteration and
        # beam_width^2 for every following iteration.
        # Sort all candidates based on their score...
        next_candidates.sort(key=lambda k: k['score'].sum(), reverse=True)
        # ...and keep only the beam_width best.
        candidates = next_candidates[:beam_width]
    # Take the output / attention of the best candidate.
    output = candidates[0]['output']
    attention_scores = candidates[0]['attention']

    # Reorganize the output: concatenate the per-step predictions and convert
    # to a numpy array so the comparison below is element-wise.
    output = np.array(torch.cat(output, dim=-1).tolist())
    # Truncate at the first stop token, if any.
    first_stop = np.where(output == STOP_TOKEN_ID)[0]
    if len(first_stop) > 0:
        output = output[:first_stop[0]]
    return output, np.concatenate(attention_scores, axis=1)
# Usage elsewhere in the class (self.bsw holds the beam width):
outputs, _ = self._beam_search(batch, self.bsw, max_len)

# `batch` is:
# {
#     'input_t': batched_article,      # [padded_seq_len, batch_size]
#     'target_t': batched_abstract,
#     'input_mask': article_mask,
#     'target_mask': abstract_mask,
#     'input_len': article_len,
#     'target_len': abstract_len
# }
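
# ---------------------------------------------------------------------------
# A minimal sketch of how such a batch dict could be assembled from already
# tokenized ids, following the [padded_seq_len, batch_size] layout noted
# above. The sequences, the pad_id and the _pad_and_stack helper below are
# illustrative assumptions, not part of the gist; wiring the result into a
# real model and vocabulary is left out.

pad_id = 0
articles = [torch.tensor([4, 8, 15, 16]), torch.tensor([23, 42])]   # token ids
abstracts = [torch.tensor([4, 8]), torch.tensor([23, 42, 7])]

def _pad_and_stack(seqs, pad_id):
    """Pad to the longest sequence and stack as [padded_seq_len, batch_size]."""
    max_len = max(len(s) for s in seqs)
    padded = torch.full((max_len, len(seqs)), pad_id, dtype=torch.long)
    for j, s in enumerate(seqs):
        padded[:len(s), j] = s
    return padded

batched_article = _pad_and_stack(articles, pad_id)
batched_abstract = _pad_and_stack(abstracts, pad_id)

batch = {
    'input_t': batched_article,                  # [padded_seq_len, batch_size]
    'target_t': batched_abstract,
    'input_mask': (batched_article != pad_id),   # True where not padding
    'target_mask': (batched_abstract != pad_id),
    'input_len': torch.tensor([len(s) for s in articles]),
    'target_len': torch.tensor([len(s) for s in abstracts]),
}
# Then, inside the class:
# outputs, _ = self._beam_search(batch, beam_width=5, max_len=100)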