Skip to content

Instantly share code, notes, and snippets.

@fzyzcjy
Created December 20, 2022 03:07
Show Gist options
  • Save fzyzcjy/fab4bf82c62f23b3432123c84f14a2c6 to your computer and use it in GitHub Desktop.
Save fzyzcjy/fab4bf82c62f23b3432123c84f14a2c6 to your computer and use it in GitHub Desktop.
Speed up HuggingFace beam search by 10x
import warnings
from collections import UserDict, defaultdict
from typing import Optional, Tuple, Any
import torch
from transformers import BeamScorer, BeamSearchScorer
from transformers.generation import BeamHypotheses
from ...utils.torch_utils import first_several_nonzero_indices
class MyBeamSearchScorer(BeamScorer):
    """Beam scorer whose per-step `process` relies mostly on batched tensor ops.

    Finished hypotheses are collected in one `BeamHypotheses` per batch entry and
    the per-batch "done" flags live in a boolean tensor (`self._done`).

    NOTE(review): `process` reshapes its selections with `self.num_beams` while
    the output buffers are allocated with `self.group_size`; this presumably
    restricts the class to `num_beam_groups == 1` -- confirm before using it
    for group beam search.
    """

    def __init__(
        self,
        batch_size: int,
        num_beams: int,
        device: torch.device,
        length_penalty: Optional[float] = 1.0,
        do_early_stopping: Optional[bool] = False,
        num_beam_hyps_to_keep: Optional[int] = 1,
        num_beam_groups: Optional[int] = 1,
        **kwargs,
    ):
        """Create the scorer state for `batch_size` independent searches.

        Args:
            batch_size: Number of `BeamHypotheses` containers (one per batch entry).
            num_beams: Number of beams; must be an `int` greater than 1.
            device: Device on which the `_done` flags and final scores are created.
            length_penalty: Forwarded to every `BeamHypotheses`.
            do_early_stopping: Forwarded to every `BeamHypotheses` as `early_stopping`.
            num_beam_hyps_to_keep: Hypotheses returned per batch entry by `finalize`.
            num_beam_groups: Number of beam groups; must divide `num_beams`.
            **kwargs: Only inspected for the deprecated `max_length` key.

        Raises:
            ValueError: If `num_beams` or `num_beam_groups` are invalid.
        """
        self.num_beams = num_beams
        self.device = device
        self.length_penalty = length_penalty
        self.do_early_stopping = do_early_stopping
        self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
        self.num_beam_groups = num_beam_groups
        self.group_size = self.num_beams // self.num_beam_groups

        self._is_init = False
        self._beam_hyps = [
            BeamHypotheses(
                num_beams=self.num_beams,
                length_penalty=self.length_penalty,
                early_stopping=self.do_early_stopping,
            )
            for _ in range(batch_size)
        ]
        # One flag per batch entry; kept as a tensor so `process` can use it as a mask.
        self._done: torch.Tensor = torch.zeros(batch_size, dtype=torch.bool, device=self.device)

        if not isinstance(num_beams, int) or num_beams <= 1:
            raise ValueError(
                f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1,"
                " one should make use of `greedy_search` instead."
            )

        if not isinstance(num_beam_groups, int) or (num_beam_groups > num_beams) or (num_beams % num_beam_groups != 0):
            raise ValueError(
                "`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` has to be"
                f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
            )

        if "max_length" in kwargs:
            warnings.warn(
                "Passing `max_length` to BeamSearchScorer is deprecated and has no effect. "
                "`max_length` should be passed directly to `beam_search(...)`, `beam_sample(...)`"
                ", or `group_beam_search(...)`."
            )

        # Accumulator for optional timing measurements (name -> seconds).
        self.t_dict = defaultdict(lambda: 0.0)

    @property
    def is_done(self) -> bool:
        """Whether every batch entry is finished (a 0-d bool tensor, usable as a truth value)."""
        return self._done.all()

    def process(
        self,
        input_ids: torch.LongTensor,
        next_scores: torch.FloatTensor,
        next_tokens: torch.LongTensor,
        next_indices: torch.LongTensor,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        beam_indices: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor]:
        """Consume one step of candidates and pick the beams for the next step.

        EOS candidates of unfinished batch entries are stored as finished
        hypotheses; the remaining candidates are selected through
        `first_several_nonzero_indices`.

        Args:
            input_ids: Current sequences, one row per (batch entry, beam).
            next_scores: Candidate scores, one row per batch entry.
            next_tokens: Candidate token ids, same layout as `next_scores`.
            next_indices: Within-batch beam index each candidate extends.
            pad_token_id: Token written into the outputs of finished batch entries.
            eos_token_id: End-of-sequence token id; `None` disables EOS handling.
            beam_indices: Optional per-beam index history (tuples), extended for
                every finished hypothesis.

        Returns:
            A `UserDict` with flattened `next_beam_scores`, `next_beam_tokens`
            and `next_beam_indices`.

        Raises:
            ValueError: If the number of rows in `input_ids` does not match the
                scorer's batch size times `group_size`.
        """
        cur_len = input_ids.shape[-1]
        batch_size = len(self._beam_hyps)
        if not (batch_size == (input_ids.shape[0] // self.group_size)):
            if self.num_beam_groups > 1:
                raise ValueError(
                    f"A group beam size of {input_ids.shape[0]} is used as the input, but a group beam "
                    f"size of {self.group_size} is expected by the beam scorer."
                )
            else:
                raise ValueError(
                    f"A beam size of {input_ids.shape[0]} is used as the input, but a beam size of "
                    f"{self.group_size} is expected by the beam scorer."
                )

        device = input_ids.device
        next_beam_scores = torch.zeros((batch_size, self.group_size), dtype=next_scores.dtype, device=device)
        next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
        next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)

        # Row index into `input_ids` for every candidate.
        batch_beam_indices = torch.arange(batch_size, device=device)[:, None] * self.group_size + next_indices

        # --- EOS candidates -------------------------------------------------
        # Compare against `None` explicitly: an `eos_token_id` of 0 is a valid
        # token id and must not be treated as "no EOS token".
        if eos_token_id is None:
            is_eos = torch.zeros_like(next_tokens, dtype=torch.bool)
        else:
            is_eos = next_tokens == eos_token_id
        is_eos_and_non_done = (~self._done[:, None]) & is_eos

        # Move only the (usually few) EOS candidates to the host in one go, so
        # the Python loop below does not synchronise with the device per element.
        next_indices_selected = next_indices[is_eos_and_non_done].cpu().numpy()
        next_scores_selected = next_scores[is_eos_and_non_done].cpu().numpy()
        is_eos_and_non_done_indices = is_eos_and_non_done.nonzero().cpu().numpy()

        for i, (batch_idx, beam_token_rank) in enumerate(is_eos_and_non_done_indices):
            batch_beam_idx = batch_idx * self.group_size + next_indices_selected[i]
            # if beam_token does not belong to top num_beams tokens, it should not be added
            is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
            if is_beam_token_worse_than_top_num_beams:
                continue
            if beam_indices is not None:
                beam_index = beam_indices[batch_beam_idx]
                beam_index = beam_index + (batch_beam_idx,)
            else:
                beam_index = None
            self._beam_hyps[batch_idx].add(
                input_ids[batch_beam_idx].clone(),
                next_scores_selected[i].item(),
                beam_indices=beam_index,
            )

        # --- non-EOS candidates ---------------------------------------------
        # `first_several_nonzero_indices` is defined outside this file; it is
        # presumably an index selecting, per enabled row, the first `k`
        # positions flagged as non-EOS -- verify against the helper.
        first_several_non_eos = first_several_nonzero_indices(
            (~is_eos).int(), batch_enable=~self._done, k=self.num_beams
        )
        next_beam_scores[:] = next_scores[first_several_non_eos].reshape((-1, self.num_beams))
        next_beam_tokens[:] = next_tokens[first_several_non_eos].reshape((-1, self.num_beams))
        next_beam_indices[:] = batch_beam_indices[first_several_non_eos].reshape((-1, self.num_beams))

        # --- batch entries that were already done ---------------------------
        next_beam_scores[self._done, :] = 0
        if pad_token_id is not None:
            next_beam_tokens[self._done, :] = pad_token_id
        next_beam_indices[self._done, :] = 0

        # Check if we are done so that we can save a pad step if all(done)
        next_scores_max = next_scores.max(dim=1)[0].cpu().numpy()
        self._done |= torch.tensor(
            [
                beam_hyp.is_done(next_scores_max[batch_idx], cur_len)
                for batch_idx, beam_hyp in enumerate(self._beam_hyps)
            ],
            device=device,
        )

        return UserDict(
            {
                "next_beam_scores": next_beam_scores.view(-1),
                "next_beam_tokens": next_beam_tokens.view(-1),
                "next_beam_indices": next_beam_indices.view(-1),
            }
        )

    def finalize(
        self,
        input_ids: torch.LongTensor,
        final_beam_scores: torch.FloatTensor,
        final_beam_tokens: torch.LongTensor,
        final_beam_indices: torch.LongTensor,
        max_length: int,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        beam_indices: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.LongTensor]:
        """Close all open beams and assemble the best hypotheses into tensors.

        Args:
            input_ids: Final sequences, one row per (batch entry, beam).
            final_beam_scores: Score of every row of `input_ids`.
            final_beam_tokens: Unused by this implementation.
            final_beam_indices: Unused by this implementation.
            max_length: Upper bound for the returned sequence length (`None`
                means no bound).
            pad_token_id: Padding token; required when hypotheses differ in length.
            eos_token_id: Token appended after a hypothesis when it fits.
            beam_indices: Optional per-beam index history stored with each hypothesis.

        Returns:
            A `UserDict` with `sequences`, `sequence_scores` and `beam_indices`
            (`None` when no index history was tracked).
        """
        batch_size = len(self._beam_hyps)

        # finalize all open beam hypotheses and add to generated hypotheses
        for batch_idx, beam_hyp in enumerate(self._beam_hyps):
            if self._done[batch_idx]:
                continue

            # all open beam hypotheses are added to the beam hypothesis
            # beam hypothesis class automatically keeps the best beams
            for beam_id in range(self.num_beams):
                batch_beam_idx = batch_idx * self.num_beams + beam_id
                final_score = final_beam_scores[batch_beam_idx].item()
                final_tokens = input_ids[batch_beam_idx]
                beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
                beam_hyp.add(final_tokens, final_score, beam_indices=beam_index)

        # select the best hypotheses
        sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
        best = []
        best_indices = []
        best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)

        # retrieve best hypotheses
        for i, beam_hyp in enumerate(self._beam_hyps):
            # ascending by score, so `pop()` yields the best remaining hypothesis
            sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
            for j in range(self.num_beam_hyps_to_keep):
                best_hyp_tuple = sorted_hyps.pop()
                best_score = best_hyp_tuple[0]
                best_hyp = best_hyp_tuple[1]
                best_index = best_hyp_tuple[2]
                sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)

                # append hyp to lists
                best.append(best_hyp)

                # append indices to list
                best_indices.append(best_index)

                best_scores[i * self.num_beam_hyps_to_keep + j] = best_score

        # prepare for adding eos
        sent_lengths_max = sent_lengths.max().item() + 1
        sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
        decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)

        if len(best_indices) > 0 and best_indices[0] is not None:
            indices: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
        else:
            indices = None

        # shorter batches are padded if needed
        if sent_lengths.min().item() != sent_lengths.max().item():
            assert pad_token_id is not None, "`pad_token_id` has to be defined"
            decoded.fill_(pad_token_id)

        if indices is not None:
            indices.fill_(-1)

        # fill with hypotheses and eos_token_id if the latter fits in
        for i, (hypo, best_idx) in enumerate(zip(best, best_indices)):
            decoded[i, : sent_lengths[i]] = hypo

            if indices is not None:
                indices[i, : len(best_idx)] = torch.tensor(best_idx)

            if sent_lengths[i] < sent_max_len:
                decoded[i, sent_lengths[i]] = eos_token_id

        return UserDict(
            {
                "sequences": decoded,
                "sequence_scores": best_scores,
                "beam_indices": indices,
            }
        )
class BeamSearchScorerForComparison:
    """Debug wrapper that runs two beam scorers side by side and asserts they agree.

    Every call is forwarded to both `MyBeamSearchScorer` (`ours`) and
    `BeamSearchScorer` (`theirs`); outputs and `is_done` state are compared
    with `assert`, and the result of `theirs` is what callers receive.
    """

    def __init__(self, **kwargs: Any):
        """Build both scorers from the same keyword arguments."""
        self.ours = MyBeamSearchScorer(**kwargs)
        self.theirs = BeamSearchScorer(**kwargs)

    @property
    def _beam_hyps(self):
        """Hypothesis containers of `ours`, after checking both scorers hold equally many."""
        assert len(self.ours._beam_hyps) == len(self.theirs._beam_hyps)
        return self.ours._beam_hyps

    @property
    def num_beams(self):
        """Number of beams, asserted to be identical for both scorers."""
        assert self.ours.num_beams == self.theirs.num_beams
        return self.ours.num_beams

    @property
    def is_done(self):
        """Done state of `ours`, asserted to match `theirs`."""
        assert self.ours.is_done == self.theirs.is_done
        return self.ours.is_done

    def process(self, *args: Any, **kwargs: Any):
        """Run `process` on both scorers, compare the results, return `theirs`' output."""
        ours_output = self.ours.process(*args, **kwargs)
        theirs_output = self.theirs.process(*args, **kwargs)
        assert isinstance(ours_output, UserDict) and isinstance(theirs_output, UserDict)
        self._check_output_equality(ours_output, theirs_output)
        self._check_state_equality()
        return theirs_output

    def finalize(self, *args: Any, **kwargs: Any):
        """Run `finalize` on both scorers, compare the results, return `theirs`' output."""
        ours_output = self.ours.finalize(*args, **kwargs)
        theirs_output = self.theirs.finalize(*args, **kwargs)
        assert isinstance(ours_output, UserDict) and isinstance(theirs_output, UserDict)
        self._check_output_equality(ours_output, theirs_output)
        self._check_state_equality()
        return theirs_output

    @staticmethod
    def _check_output_equality(ours_output: UserDict, theirs_output: UserDict):
        """Assert both outputs have the same keys and (approximately) equal values."""
        assert set(ours_output.keys()) == set(theirs_output.keys())
        for k in ours_output.keys():
            assert BeamSearchScorerForComparison._tensor_equals(ours_output[k], theirs_output[k]), \
                f'output not equal. key={k} ' \
                f'ours_output={ours_output[k]} theirs_output={theirs_output[k]} '

    def _check_state_equality(self):
        """Assert both scorers agree on whether the search is finished."""
        assert self.ours.is_done == self.theirs.is_done

    @staticmethod
    def _tensor_equals(a: Optional[torch.Tensor], b: Optional[torch.Tensor]):
        """Return True when both values are `None` or both are close tensors of equal shape.

        A `None`/tensor mismatch or a shape mismatch yields `False` instead of
        an exception from `torch.allclose`, so the caller's assertion message
        (which names the offending key) is what gets reported.
        """
        if a is None or b is None:
            return a is None and b is None
        return a.shape == b.shape and bool(torch.allclose(a, b))
import inspect
import warnings
from collections import defaultdict, UserDict
from timeit import default_timer
from typing import Optional, Iterable, Union, Callable, List, Tuple
import numpy as np
import torch
import torch.distributed as dist
from torch import nn
from torch.distributions.constraints import Constraint
from transformers import LogitsProcessorList, StoppingCriteriaList, ConstrainedBeamSearchScorer, DisjunctiveConstraint, \
PhrasalConstraint, BeamScorer
from transformers.generation import validate_stopping_criteria, BeamHypotheses
from transformers.generation.utils import GenerateOutput, GenerationMixin, BeamSearchOutput, \
BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput, SampleOutput, SampleEncoderDecoderOutput, \
SampleDecoderOnlyOutput
# copied and modified from: huggingface utils.py
from transformers.pytorch_utils import torch_int_div
from ...utils.huggingface.beam_search import MyBeamSearchScorer
class MyBeamSearchScorerViaNumpy(BeamScorer):
    """Beam scorer that copies its inputs to NumPy once and runs the loops on the host.

    The control flow is a per-batch / per-candidate Python loop, executed on
    NumPy arrays rather than on torch tensors. Wall-clock time of the main
    phases of `process` is accumulated in `self.t_dict`.
    """

    def __init__(
        self,
        batch_size: int,
        num_beams: int,
        device: torch.device,
        length_penalty: Optional[float] = 1.0,
        do_early_stopping: Optional[bool] = False,
        num_beam_hyps_to_keep: Optional[int] = 1,
        num_beam_groups: Optional[int] = 1,
        **kwargs,
    ):
        """Create the scorer state for `batch_size` independent searches.

        Args:
            batch_size: Number of `BeamHypotheses` containers (one per batch entry).
            num_beams: Number of beams; must be an `int` greater than 1.
            device: Device on which the `_done` flags are created.
            length_penalty: Forwarded to every `BeamHypotheses`.
            do_early_stopping: Forwarded to every `BeamHypotheses` as `early_stopping`.
            num_beam_hyps_to_keep: Hypotheses returned per batch entry by `finalize`.
            num_beam_groups: Number of beam groups; must divide `num_beams`.
            **kwargs: Only inspected for the deprecated `max_length` key.

        Raises:
            ValueError: If `num_beams` or `num_beam_groups` are invalid.
        """
        self.num_beams = num_beams
        self.device = device
        self.length_penalty = length_penalty
        self.do_early_stopping = do_early_stopping
        self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
        self.num_beam_groups = num_beam_groups
        self.group_size = self.num_beams // self.num_beam_groups

        self._is_init = False
        self._beam_hyps = [
            BeamHypotheses(
                num_beams=self.num_beams,
                length_penalty=self.length_penalty,
                early_stopping=self.do_early_stopping,
            )
            for _ in range(batch_size)
        ]
        self._done = torch.zeros(batch_size, dtype=torch.bool, device=self.device)

        # Accumulated wall-clock seconds per phase of `process`.
        self.t_dict = defaultdict(lambda: 0.0)

        if not isinstance(num_beams, int) or num_beams <= 1:
            raise ValueError(
                f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1,"
                " one should make use of `greedy_search` instead."
            )

        if not isinstance(num_beam_groups, int) or (num_beam_groups > num_beams) or (num_beams % num_beam_groups != 0):
            raise ValueError(
                "`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` has to be"
                f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
            )

        if "max_length" in kwargs:
            warnings.warn(
                "Passing `max_length` to BeamSearchScorer is deprecated and has no effect. "
                "`max_length` should be passed directly to `beam_search(...)`, `beam_sample(...)`"
                ", or `group_beam_search(...)`."
            )

    @property
    def is_done(self) -> bool:
        """Whether every batch entry is finished (a 0-d bool tensor, usable as a truth value)."""
        return self._done.all()

    def process(
        self,
        input_ids: torch.LongTensor,
        next_scores: torch.FloatTensor,
        next_tokens: torch.LongTensor,
        next_indices: torch.LongTensor,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        beam_indices: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor]:
        """Consume one step of candidates and pick the beams for the next step.

        Args:
            input_ids: Current sequences, one row per (batch entry, beam).
            next_scores: Candidate scores, one row per batch entry.
            next_tokens: Candidate token ids, same layout as `next_scores`.
            next_indices: Within-batch beam index each candidate extends.
            pad_token_id: Token written into the outputs of finished batch entries.
            eos_token_id: End-of-sequence token id; `None` disables EOS handling.
            beam_indices: Optional per-beam index history (tuples), extended for
                every finished hypothesis.

        Returns:
            A `UserDict` with flattened `next_beam_scores`, `next_beam_tokens`
            and `next_beam_indices`, created on the device of `input_ids`.

        Raises:
            ValueError: On a batch/beam size mismatch, when a finished batch
                entry is processed without `eos_token_id`/`pad_token_id`, or
                when fewer than `group_size` non-EOS candidates are available.
        """
        t0 = default_timer()

        # Single device -> host transfer; everything below works on NumPy arrays.
        device = input_ids.device
        input_ids = input_ids.cpu().numpy()
        next_scores = next_scores.cpu().numpy()
        next_tokens = next_tokens.cpu().numpy()
        next_indices = next_indices.cpu().numpy()

        self.t_dict['move_to_numpy'] += default_timer() - t0
        t0 = default_timer()

        cur_len = input_ids.shape[-1]
        batch_size = len(self._beam_hyps)
        if not (batch_size == (input_ids.shape[0] // self.group_size)):
            if self.num_beam_groups > 1:
                raise ValueError(
                    f"A group beam size of {input_ids.shape[0]} is used as the input, but a group beam "
                    f"size of {self.group_size} is expected by the beam scorer."
                )
            else:
                raise ValueError(
                    f"A beam size of {input_ids.shape[0]} is used as the input, but a beam size of "
                    f"{self.group_size} is expected by the beam scorer."
                )

        next_beam_scores = np.zeros((batch_size, self.group_size), dtype=next_scores.dtype)
        next_beam_tokens = np.zeros((batch_size, self.group_size), dtype=next_tokens.dtype)
        next_beam_indices = np.zeros((batch_size, self.group_size), dtype=next_indices.dtype)

        self.t_dict['create_array'] += default_timer() - t0
        t0 = default_timer()

        for batch_idx, beam_hyp in enumerate(self._beam_hyps):
            if self._done[batch_idx]:
                if self.num_beams < len(beam_hyp):
                    raise ValueError(f"Batch can only be done if at least {self.num_beams} beams have been generated")
                if eos_token_id is None or pad_token_id is None:
                    raise ValueError("Generated beams >= num_beams -> eos_token_id and pad_token have to be defined")
                # pad the batch
                next_beam_scores[batch_idx, :] = 0
                next_beam_tokens[batch_idx, :] = pad_token_id
                next_beam_indices[batch_idx, :] = 0
                continue

            # next tokens for this sentence
            beam_idx = 0
            for beam_token_rank, (next_token, next_score, next_index) in enumerate(
                zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
            ):
                batch_beam_idx = batch_idx * self.group_size + next_index
                # add to generated hypotheses if end of sentence
                if (eos_token_id is not None) and (next_token.item() == eos_token_id):
                    # if beam_token does not belong to top num_beams tokens, it should not be added
                    is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
                    if is_beam_token_worse_than_top_num_beams:
                        continue
                    if beam_indices is not None:
                        beam_index = beam_indices[batch_beam_idx]
                        beam_index = beam_index + (batch_beam_idx,)
                    else:
                        beam_index = None

                    beam_hyp.add(
                        input_ids[batch_beam_idx].copy(),
                        next_score.item(),
                        beam_indices=beam_index,
                    )
                else:
                    # add next predicted token since it is not eos_token
                    next_beam_scores[batch_idx, beam_idx] = next_score
                    next_beam_tokens[batch_idx, beam_idx] = next_token
                    next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
                    beam_idx += 1

                # once the beam for next step is full, don't add more tokens to it.
                if beam_idx == self.group_size:
                    break

            if beam_idx < self.group_size:
                raise ValueError(
                    f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id:"
                    f" {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
                )

            # Check if we are done so that we can save a pad step if all(done)
            self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
                next_scores[batch_idx].max().item(), cur_len
            )

        self.t_dict['body'] += default_timer() - t0
        t0 = default_timer()

        # Host -> device transfer of the three result arrays.
        ans = UserDict(
            {
                "next_beam_scores": torch.tensor(next_beam_scores, device=device).view(-1),
                "next_beam_tokens": torch.tensor(next_beam_tokens, device=device).view(-1),
                "next_beam_indices": torch.tensor(next_beam_indices, device=device).view(-1),
            }
        )

        self.t_dict['ans'] += default_timer() - t0

        return ans

    def finalize(
        self,
        input_ids: torch.LongTensor,
        final_beam_scores: torch.FloatTensor,
        final_beam_tokens: torch.LongTensor,
        final_beam_indices: torch.LongTensor,
        max_length: int,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        beam_indices: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.LongTensor]:
        """Close all open beams and assemble the best hypotheses into tensors.

        Args:
            input_ids: Final sequences, one row per (batch entry, beam).
            final_beam_scores: Score of every row of `input_ids`.
            final_beam_tokens: Unused by this implementation.
            final_beam_indices: Unused by this implementation.
            max_length: Upper bound for the returned sequence length (`None`
                means no bound).
            pad_token_id: Padding token; required when hypotheses differ in length.
            eos_token_id: Token appended after a hypothesis when it fits.
            beam_indices: Optional per-beam index history stored with each hypothesis.

        Returns:
            A `UserDict` with `sequences`, `sequence_scores` and `beam_indices`
            (`None` when no index history was tracked), created on the device
            of `input_ids`.
        """
        batch_size = len(self._beam_hyps)

        device = input_ids.device
        input_ids = input_ids.cpu().numpy()
        final_beam_scores = final_beam_scores.cpu().numpy()

        # finalize all open beam hypotheses and add to generated hypotheses
        for batch_idx, beam_hyp in enumerate(self._beam_hyps):
            if self._done[batch_idx]:
                continue

            # all open beam hypotheses are added to the beam hypothesis
            # beam hypothesis class automatically keeps the best beams
            for beam_id in range(self.num_beams):
                batch_beam_idx = batch_idx * self.num_beams + beam_id
                final_score = final_beam_scores[batch_beam_idx].item()
                final_tokens = input_ids[batch_beam_idx]
                beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
                beam_hyp.add(final_tokens, final_score, beam_indices=beam_index)

        # select the best hypotheses
        num_kept = batch_size * self.num_beam_hyps_to_keep
        # Explicit int64: a bare `int` dtype is the platform C long in older
        # NumPy releases (32-bit on Windows), which would make `sequences`
        # come back as a non-Long tensor there.
        sent_lengths = np.zeros(num_kept, dtype=np.int64)
        best = []
        best_indices = []
        best_scores = np.zeros(num_kept, dtype=np.float32)

        # retrieve best hypotheses
        for i, beam_hyp in enumerate(self._beam_hyps):
            # ascending by score, so `pop()` yields the best remaining hypothesis
            sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
            for j in range(self.num_beam_hyps_to_keep):
                best_hyp_tuple = sorted_hyps.pop()
                best_score = best_hyp_tuple[0]
                best_hyp = best_hyp_tuple[1]
                best_index = best_hyp_tuple[2]
                sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)

                # append hyp to lists
                best.append(best_hyp)

                # append indices to list
                best_indices.append(best_index)

                best_scores[i * self.num_beam_hyps_to_keep + j] = best_score

        # prepare for adding eos
        sent_lengths_max = sent_lengths.max().item() + 1
        sent_max_len = int(min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max)
        decoded = np.zeros((num_kept, sent_max_len), dtype=np.int64)

        if len(best_indices) > 0 and best_indices[0] is not None:
            indices = np.zeros((num_kept, sent_max_len), dtype=np.int64)
        else:
            indices = None

        # shorter batches are padded if needed
        if sent_lengths.min().item() != sent_lengths.max().item():
            assert pad_token_id is not None, "`pad_token_id` has to be defined"
            decoded.fill(pad_token_id)

        if indices is not None:
            indices.fill(-1)

        # fill with hypotheses and eos_token_id if the latter fits in
        for i, (hypo, best_idx) in enumerate(zip(best, best_indices)):
            decoded[i, : sent_lengths[i]] = hypo

            if indices is not None:
                # `best_idx` is a tuple whose element types depend on the
                # caller; `torch.tensor` is kept for the conversion.
                indices[i, : len(best_idx)] = torch.tensor(best_idx)

            if sent_lengths[i] < sent_max_len:
                decoded[i, sent_lengths[i]] = eos_token_id

        return UserDict(
            {
                "sequences": torch.tensor(decoded, device=device),
                "sequence_scores": torch.tensor(best_scores, device=device),
                "beam_indices": torch.tensor(indices, device=device) if indices is not None else None,
            }
        )
# Beam scorer implementation bound to the name `ChosenBeamSearchScorer`.
# Exactly one of the assignments below should be active; the others are kept
# as comments so the implementation can be switched by (un)commenting a line.
# ChosenBeamSearchScorer = BeamSearchScorer
# ChosenBeamSearchScorer = BeamSearchScorerForComparison
ChosenBeamSearchScorer = MyBeamSearchScorer
# ChosenBeamSearchScorer = MyBeamSearchScorerViaNumpy
@torch.no_grad()
def my_generate(
self: GenerationMixin,
inputs: Optional[torch.Tensor] = None,
max_length: Optional[int] = None,
min_length: Optional[int] = None,
do_sample: Optional[bool] = None,
early_stopping: Optional[bool] = None,
num_beams: Optional[int] = None,
temperature: Optional[float] = None,
penalty_alpha: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
typical_p: Optional[float] = None,
repetition_penalty: Optional[float] = None,
bad_words_ids: Optional[Iterable[int]] = None,
force_words_ids: Optional[Union[Iterable[int], Iterable[Iterable[int]]]] = None,
bos_token_id: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
length_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
encoder_no_repeat_ngram_size: Optional[int] = None,
num_return_sequences: Optional[int] = None,
max_time: Optional[float] = None,
max_new_tokens: Optional[int] = None,
decoder_start_token_id: Optional[int] = None,
use_cache: Optional[bool] = None,
num_beam_groups: Optional[int] = None,
diversity_penalty: Optional[float] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
logits_processor: Optional[LogitsProcessorList] = None,
renormalize_logits: Optional[bool] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
constraints: Optional[List[Constraint]] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
forced_bos_token_id: Optional[int] = None,
forced_eos_token_id: Optional[int] = None,
remove_invalid_values: Optional[bool] = None,
synced_gpus: Optional[bool] = False,
exponential_decay_length_penalty: Optional[Tuple[int, float]] = None,
suppress_tokens: Optional[List[int]] = None,
begin_suppress_tokens: Optional[List[int]] = None,
forced_decoder_ids: Optional[List[List[int]]] = None,
**model_kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
print('my_generate called')
# 0. Validate the `.generate()` call
self._validate_model_class()
self._validate_model_kwargs(model_kwargs.copy())
# 1. Set generation parameters if not already defined
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
num_beams = num_beams if num_beams is not None else self.config.num_beams
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups
do_sample = do_sample if do_sample is not None else self.config.do_sample
num_return_sequences = (
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
)
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
if eos_token_id is None and hasattr(self.config, "decoder"):
eos_token_id = self.config.decoder.eos_token_id
if pad_token_id is None and eos_token_id is not None:
if model_kwargs.get("attention_mask", None) is None:
logger.warning(
"The attention mask and the pad token id were not set. As a consequence, you may observe "
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
)
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
pad_token_id = eos_token_id
output_scores = output_scores if output_scores is not None else self.config.output_scores
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict_in_generate = (
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
)
# 2. Define model inputs
# inputs_tensor has to be defined
# model_input_name is defined if model-specific keyword input is passed
# otherwise model_input_name is None
# all model-specific keyword inputs are removed from `model_kwargs`
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(inputs, bos_token_id, model_kwargs)
batch_size = inputs_tensor.shape[0]
# 3. Define other model kwargs
model_kwargs["output_attentions"] = output_attentions
model_kwargs["output_hidden_states"] = output_hidden_states
model_kwargs["use_cache"] = use_cache
accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
requires_attention_mask = "encoder_outputs" not in model_kwargs
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
inputs_tensor, pad_token_id, eos_token_id
)
# decoder-only models should use left-padding for generation
if not self.config.is_encoder_decoder:
if pad_token_id is not None and torch.sum(inputs_tensor[:, -1] == pad_token_id) > 0:
logger.warning(
"A decoder-only architecture is being used, but right-padding was detected! For correct "
"generation results, please set `padding_side='left'` when initializing the tokenizer."
)
if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
# if model is encoder decoder encoder_outputs are created
# and added to `model_kwargs`
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
inputs_tensor, model_kwargs, model_input_name
)
# 4. Prepare `input_ids` which will be used for auto-regressive generation
if self.config.is_encoder_decoder:
input_ids = self._prepare_decoder_input_ids_for_generation(
batch_size,
decoder_start_token_id=decoder_start_token_id,
bos_token_id=bos_token_id,
model_kwargs=model_kwargs,
device=inputs_tensor.device,
)
else:
# if decoder-only then inputs_tensor has to be `input_ids`
input_ids = inputs_tensor
# 5. Prepare `max_length` depending on other stopping criteria.
input_ids_seq_length = input_ids.shape[-1]
if max_length is None and max_new_tokens is None:
warnings.warn(
"Neither `max_length` nor `max_new_tokens` has been set, `max_length` will default to "
f"{self.config.max_length} (`self.config.max_length`). Controlling `max_length` via the config is "
"deprecated and `max_length` will be removed from the config in v5 of Transformers -- we recommend "
"using `max_new_tokens` to control the maximum length of the generation.",
UserWarning,
)
elif max_length is None and max_new_tokens is not None:
max_length = max_new_tokens + input_ids_seq_length
elif max_length is not None and max_new_tokens is not None:
raise ValueError(
"Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a"
" limit to the generated output length. Remove one of those arguments. Please refer to the"
" documentation for more information. "
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
)
# default to config if still None
max_length = max_length if max_length is not None else self.config.max_length
min_length = min_length if min_length is not None else self.config.min_length
if min_length is not None and min_length > max_length:
raise ValueError(
f"Unfeasible length constraints: the minimum length ({min_length}) is larger than the maximum "
f"length ({max_length})"
)
if input_ids_seq_length >= max_length:
input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
logger.warning(
f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
f" {max_length}. This can lead to unexpected behavior. You should consider increasing "
"`max_new_tokens`."
)
# 6. determine generation mode
is_constraint_gen_mode = constraints is not None or force_words_ids is not None
is_contrastive_search_gen_mode = (
top_k is not None and top_k > 1 and do_sample is False and penalty_alpha is not None and penalty_alpha > 0
)
is_greedy_gen_mode = (
(num_beams == 1)
and (num_beam_groups == 1)
and do_sample is False
and not is_constraint_gen_mode
and not is_contrastive_search_gen_mode
)
is_sample_gen_mode = (
(num_beams == 1)
and (num_beam_groups == 1)
and do_sample is True
and not is_constraint_gen_mode
and not is_contrastive_search_gen_mode
)
is_beam_gen_mode = (
(num_beams > 1)
and (num_beam_groups == 1)
and do_sample is False
and not is_constraint_gen_mode
and not is_contrastive_search_gen_mode
)
is_beam_sample_gen_mode = (
(num_beams > 1)
and (num_beam_groups == 1)
and do_sample is True
and not is_constraint_gen_mode
and not is_contrastive_search_gen_mode
)
is_group_beam_gen_mode = (
(num_beams > 1)
and (num_beam_groups > 1)
and not is_constraint_gen_mode
and not is_contrastive_search_gen_mode
)
if num_beam_groups > num_beams:
raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
if is_group_beam_gen_mode and do_sample is True:
raise ValueError(
"Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
)
if self.device.type != input_ids.device.type:
warnings.warn(
"You are calling .generate() with the `input_ids` being on a device type different"
f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
f" is on {self.device.type}. You may experience unexpected behaviors or slower generation."
" Please make sure that you have put `input_ids` to the"
f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before"
" running `.generate()`.",
UserWarning,
)
# 7. prepare distribution pre_processing samplers
logits_processor = self._get_logits_processor(
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
input_ids_seq_length=input_ids_seq_length,
encoder_input_ids=inputs_tensor,
bad_words_ids=bad_words_ids,
min_length=min_length,
max_length=max_length,
eos_token_id=eos_token_id,
forced_bos_token_id=forced_bos_token_id,
forced_eos_token_id=forced_eos_token_id,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
num_beams=num_beams,
num_beam_groups=num_beam_groups,
diversity_penalty=diversity_penalty,
remove_invalid_values=remove_invalid_values,
exponential_decay_length_penalty=exponential_decay_length_penalty,
logits_processor=logits_processor,
renormalize_logits=renormalize_logits,
suppress_tokens=suppress_tokens,
begin_suppress_tokens=begin_suppress_tokens,
forced_decoder_ids=forced_decoder_ids,
)
# 8. prepare stopping criteria
stopping_criteria = self._get_stopping_criteria(
max_length=max_length, max_time=max_time, stopping_criteria=stopping_criteria
)
# 9. go into different generation modes
if is_greedy_gen_mode:
if num_return_sequences > 1:
raise ValueError(
f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search."
)
# 10. run greedy search
return self.greedy_search(
input_ids,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
output_scores=output_scores,
return_dict_in_generate=return_dict_in_generate,
synced_gpus=synced_gpus,
**model_kwargs,
)
elif is_contrastive_search_gen_mode:
if num_return_sequences > 1:
raise ValueError(
f"num_return_sequences has to be 1, but is {num_return_sequences} when doing contrastive search."
)
return self.contrastive_search(
input_ids,
top_k=top_k,
penalty_alpha=penalty_alpha,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
output_scores=output_scores,
return_dict_in_generate=return_dict_in_generate,
synced_gpus=synced_gpus,
**model_kwargs,
)
elif is_sample_gen_mode:
# 10. prepare logits warper
logits_warper = self._get_logits_warper(
top_k=top_k,
top_p=top_p,
typical_p=typical_p,
temperature=temperature,
num_beams=num_beams,
renormalize_logits=renormalize_logits,
)
# 11. expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids=input_ids,
expand_size=num_return_sequences,
is_encoder_decoder=self.config.is_encoder_decoder,
**model_kwargs,
)
# 12. run sample
return my_sample(
self,
input_ids,
logits_processor=logits_processor,
logits_warper=logits_warper,
stopping_criteria=stopping_criteria,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
output_scores=output_scores,
return_dict_in_generate=return_dict_in_generate,
synced_gpus=synced_gpus,
**model_kwargs,
)
elif is_beam_gen_mode:
if num_return_sequences > num_beams:
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
if stopping_criteria.max_length is None:
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
# 10. prepare beam search scorer
beam_scorer = ChosenBeamSearchScorer(
batch_size=batch_size,
num_beams=num_beams,
device=inputs_tensor.device,
length_penalty=length_penalty,
do_early_stopping=early_stopping,
num_beam_hyps_to_keep=num_return_sequences,
)
# 11. interleave input_ids with `num_beams` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids=input_ids,
expand_size=num_beams,
is_encoder_decoder=self.config.is_encoder_decoder,
**model_kwargs,
)
# 12. run beam search
return my_beam_search(
self,
input_ids,
beam_scorer,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
output_scores=output_scores,
return_dict_in_generate=return_dict_in_generate,
synced_gpus=synced_gpus,
**model_kwargs,
)
elif is_beam_sample_gen_mode:
# 10. prepare logits warper
logits_warper = self._get_logits_warper(
top_k=top_k,
top_p=top_p,
typical_p=typical_p,
temperature=temperature,
num_beams=num_beams,
renormalize_logits=renormalize_logits,
)
if stopping_criteria.max_length is None:
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
# 11. prepare beam search scorer
beam_scorer = ChosenBeamSearchScorer(
batch_size=batch_size * num_return_sequences,
num_beams=num_beams,
device=inputs_tensor.device,
length_penalty=length_penalty,
do_early_stopping=early_stopping,
)
# 12. interleave input_ids with `num_beams` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids=input_ids,
expand_size=num_beams * num_return_sequences,
is_encoder_decoder=self.config.is_encoder_decoder,
**model_kwargs,
)
# 13. run beam sample
return self.beam_sample(
input_ids,
beam_scorer,
logits_processor=logits_processor,
logits_warper=logits_warper,
stopping_criteria=stopping_criteria,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
output_scores=output_scores,
return_dict_in_generate=return_dict_in_generate,
synced_gpus=synced_gpus,
**model_kwargs,
)
elif is_group_beam_gen_mode:
if num_return_sequences > num_beams:
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
if num_beams % num_beam_groups != 0:
raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.")
if stopping_criteria.max_length is None:
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
if typical_p is not None:
raise ValueError("Decoder argument `typical_p` is not supported with beam groups.")
# 10. prepare beam search scorer
beam_scorer = MyBeamSearchScorerViaNumpy(
batch_size=batch_size,
num_beams=num_beams,
max_length=stopping_criteria.max_length,
device=inputs_tensor.device,
length_penalty=length_penalty,
do_early_stopping=early_stopping,
num_beam_hyps_to_keep=num_return_sequences,
num_beam_groups=num_beam_groups,
)
# 11. interleave input_ids with `num_beams` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids=input_ids,
expand_size=num_beams,
is_encoder_decoder=self.config.is_encoder_decoder,
**model_kwargs,
)
# 12. run beam search
return self.group_beam_search(
input_ids,
beam_scorer,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
output_scores=output_scores,
return_dict_in_generate=return_dict_in_generate,
synced_gpus=synced_gpus,
**model_kwargs,
)
elif is_constraint_gen_mode:
if num_return_sequences > num_beams:
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
if stopping_criteria.max_length is None:
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
if num_beams <= 1:
raise ValueError("`num_beams` needs to be greater than 1 for constrained generation.")
if do_sample:
raise ValueError("`do_sample` needs to be false for constrained generation.")
if num_beam_groups is not None and num_beam_groups > 1:
raise ValueError("`num_beam_groups` not supported yet for constrained generation.")
final_constraints = []
if constraints is not None:
final_constraints = constraints
if force_words_ids is not None:
def typeerror():
raise ValueError(
"`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`"
f"of positive integers, but is {force_words_ids}."
)
if not isinstance(force_words_ids, list) or len(force_words_ids) == 0:
typeerror()
for word_ids in force_words_ids:
if isinstance(word_ids[0], list):
if not isinstance(word_ids, list) or len(word_ids) == 0:
typeerror()
if any(not isinstance(token_ids, list) for token_ids in word_ids):
typeerror()
if any(
any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids)
for token_ids in word_ids
):
typeerror()
constraint = DisjunctiveConstraint(word_ids)
else:
if not isinstance(word_ids, list) or len(word_ids) == 0:
typeerror()
if any((not isinstance(token_id, int) or token_id < 0) for token_id in word_ids):
typeerror()
constraint = PhrasalConstraint(word_ids)
final_constraints.append(constraint)
# 10. prepare beam search scorer
constrained_beam_scorer = ConstrainedBeamSearchScorer(
constraints=final_constraints,
batch_size=batch_size,
num_beams=num_beams,
device=inputs_tensor.device,
length_penalty=length_penalty,
do_early_stopping=early_stopping,
num_beam_hyps_to_keep=num_return_sequences,
)
# 11. interleave input_ids with `num_beams` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids=input_ids,
expand_size=num_beams,
is_encoder_decoder=self.config.is_encoder_decoder,
**model_kwargs,
)
# 12. run beam search
return self.constrained_beam_search(
input_ids,
constrained_beam_scorer=constrained_beam_scorer,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
output_scores=output_scores,
return_dict_in_generate=return_dict_in_generate,
synced_gpus=synced_gpus,
**model_kwargs,
)
def my_sample(
    self,
    input_ids: torch.LongTensor,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    logits_warper: Optional[LogitsProcessorList] = None,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_scores: Optional[bool] = None,
    return_dict_in_generate: Optional[bool] = None,
    synced_gpus: Optional[bool] = False,
    **model_kwargs,
) -> Union[SampleOutput, torch.LongTensor]:
    """Generate token ids by multinomial sampling while accumulating per-phase timings.

    Every loop iteration runs one forward pass of ``self``, applies ``logits_processor``
    and then ``logits_warper`` to the last-position logits, draws one token per row with
    ``torch.multinomial`` and appends it to ``input_ids``. Elapsed time for the phases
    ``prepare_inputs``, ``model``, ``sample`` and ``misc`` is summed in a dict and printed
    once after the loop ends.

    NOTE(review): this looks like a timing-instrumented copy of the transformers
    ``GenerationMixin.sample`` method -- confirm against the installed version.

    Args:
        self: Model object. The code uses ``self.config``, ``self(...)``,
            ``self.prepare_inputs_for_generation`` and
            ``self._update_model_kwargs_for_generation``.
        input_ids: Prompt token ids. The row count is read from ``input_ids.shape[0]``
            and each new token is concatenated on the last dimension.
        logits_processor: Called as ``logits_processor(input_ids, next_token_logits)``.
            An empty ``LogitsProcessorList`` is used when ``None``.
        stopping_criteria: Called as ``stopping_criteria(input_ids, scores)`` after every
            step. An empty ``StoppingCriteriaList`` is used when ``None``.
        logits_warper: Called on the processed scores before sampling. An empty
            ``LogitsProcessorList`` is used when ``None``.
        max_length: Deprecated. When given, a ``UserWarning`` is emitted and the value is
            handed to ``validate_stopping_criteria``.
        pad_token_id: Token written for rows that already finished. Falls back to
            ``self.config.pad_token_id``.
        eos_token_id: A row is marked finished once it samples this id. Falls back to
            ``self.config.eos_token_id``.
        output_attentions: Forwarded to the forward call; falls back to the config value.
        output_hidden_states: Forwarded to the forward call; falls back to the config value.
        output_scores: Whether per-step scores are collected (only when
            ``return_dict_in_generate`` is also true); falls back to the config value.
        return_dict_in_generate: Selects the structured return value; falls back to the
            config value.
        synced_gpus: When true, forward passes keep running until the all-reduced flag
            shows that every peer has finished.
        **model_kwargs: Passed to ``prepare_inputs_for_generation`` and refreshed after
            every step by ``_update_model_kwargs_for_generation``.

    Returns:
        ``SampleEncoderDecoderOutput`` or ``SampleDecoderOnlyOutput`` when
        ``return_dict_in_generate`` is true (chosen by ``self.config.is_encoder_decoder``),
        otherwise the extended ``input_ids`` tensor.

    Raises:
        ValueError: If ``eos_token_id`` is set but ``pad_token_id`` is ``None``.
    """
    # init values
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
    if max_length is not None:
        warnings.warn(
            "`max_length` is deprecated in this function, use"
            " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
            UserWarning,
        )
        stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
    logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
    pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
    output_scores = output_scores if output_scores is not None else self.config.output_scores
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict_in_generate = (
        return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
    )
    # init attention / hidden states / scores tuples
    # (each stays None unless both the dict return and the matching output flag are on)
    scores = () if (return_dict_in_generate and output_scores) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
    # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
    # (these two names are only bound here and only read in the matching return branch)
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )
    # keep track of which sequences are already finished (1 = still generating, 0 = done)
    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
    # timing instrumentation: t_start is the overall start, t_dict sums seconds per phase
    # NOTE(review): `default_timer` is presumably `timeit.default_timer` -- confirm the import
    t_start = default_timer()
    t_dict = defaultdict(lambda: 0.0)
    this_peer_finished = False  # used by synced_gpus only
    # auto-regressive generation
    while True:
        # t0 always marks the start of the phase currently being timed
        t0 = default_timer()
        if synced_gpus:
            # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
            # The following logic allows an early break if all peers finished generating their sequence
            this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
            # send 0.0 if we finished, 1.0 otherwise
            dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
            # did all peers finish? the reduced sum will be 0.0 then
            if this_peer_finished_flag.item() == 0.0:
                break
        # prepare model inputs
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
        t_dict['prepare_inputs'] += default_timer() - t0
        t0 = default_timer()
        # forward pass to get next token
        outputs = self(
            **model_inputs,
            return_dict=True,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        t_dict['model'] += default_timer() - t0
        t0 = default_timer()
        if synced_gpus and this_peer_finished:
            continue  # don't waste resources running the code we don't need
        # only the logits of the last position are used to pick the next token
        next_token_logits = outputs.logits[:, -1, :]
        # pre-process distribution
        next_token_scores = logits_processor(input_ids, next_token_logits)
        next_token_scores = logits_warper(input_ids, next_token_scores)
        # Store scores, attentions and hidden_states when required
        if return_dict_in_generate:
            if output_scores:
                scores += (next_token_scores,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)
            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )
        # sample one token per row from the softmax of the processed scores
        probs = nn.functional.softmax(next_token_scores, dim=-1)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        t_dict['sample'] += default_timer() - t0
        t0 = default_timer()
        # finished sentences should have their next token be a padding token
        if eos_token_id is not None:
            if pad_token_id is None:
                raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
        # update generated ids, model inputs, and length for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )
        # if eos_token was found in one sentence, set sentence to finished
        if eos_token_id is not None:
            unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
        # stop when each sentence is finished, or if we exceed the maximum length
        if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
            if not synced_gpus:
                break
            else:
                this_peer_finished = True
        # time spent on padding, concatenation, kwargs update and the stop check;
        # not recorded for the iteration that leaves the loop via `break`
        t_dict['misc'] += default_timer() - t0
    # one-line timing report written to stdout after generation finishes
    print(
        f'sample '
        f't_total={default_timer() - t_start:.3f}s '
        't_dict=' + str({k: f'{v:.3}s' for k, v in t_dict.items()})
    )
    if return_dict_in_generate:
        if self.config.is_encoder_decoder:
            return SampleEncoderDecoderOutput(
                sequences=input_ids,
                scores=scores,
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
            )
        else:
            return SampleDecoderOnlyOutput(
                sequences=input_ids,
                scores=scores,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
            )
    else:
        return input_ids
def my_beam_search(
    self: GenerationMixin,
    input_ids: torch.LongTensor,
    beam_scorer: BeamScorer,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_scores: Optional[bool] = None,
    return_dict_in_generate: Optional[bool] = None,
    synced_gpus: Optional[bool] = False,
    **model_kwargs,
) -> Union[BeamSearchOutput, torch.LongTensor]:
    """Run beam search with ``beam_scorer`` while accumulating coarse per-phase timings.

    Every loop iteration runs one forward pass, turns the last-position logits into
    log-probabilities, adds the running beam scores, keeps the ``2 * num_beams`` best
    candidates per batch row via ``torch.topk`` and hands them to ``beam_scorer.process``.
    The chosen beams are then used to reorder ``input_ids`` and the cached ``past`` state.
    After the loop ``beam_scorer.finalize`` builds the returned sequences and a one-line
    timing report is printed.

    NOTE(review): this looks like a timing-instrumented copy of the transformers
    ``GenerationMixin.beam_search`` method -- confirm against the installed version.

    Args:
        self: Model object. The code uses ``self.config``, ``self(...)``,
            ``self.prepare_inputs_for_generation``, ``self.adjust_logits_during_generation``,
            ``self._update_model_kwargs_for_generation`` and ``self._reorder_cache``.
        input_ids: 2-D tensor unpacked as ``(batch_beam_size, cur_len)``; its first
            dimension must equal ``num_beams * batch_size``.
        beam_scorer: Supplies ``_beam_hyps`` (its length is the batch size), ``num_beams``,
            ``process``, ``finalize`` and ``is_done``.
        logits_processor: Called as ``logits_processor(input_ids, next_token_scores)``.
            An empty ``LogitsProcessorList`` is used when ``None``.
        stopping_criteria: Called as ``stopping_criteria(input_ids, scores)`` after every
            step; its ``max_length`` is passed to ``finalize``. A warning is emitted when
            it is empty.
        max_length: Deprecated. When given, a ``UserWarning`` is emitted and the value is
            handed to ``validate_stopping_criteria``.
        pad_token_id: Forwarded to the scorer; falls back to ``self.config.pad_token_id``.
        eos_token_id: Forwarded to the scorer; falls back to ``self.config.eos_token_id``.
        output_attentions: Forwarded to the forward call; falls back to the config value.
        output_hidden_states: Forwarded to the forward call; falls back to the config value.
        output_scores: Whether per-step scores and beam indices are collected (only when
            ``return_dict_in_generate`` is also true); falls back to the config value.
        return_dict_in_generate: Selects the structured return value; falls back to the
            config value.
        synced_gpus: When true, forward passes keep running until the all-reduced flag
            shows that every peer has finished.
        **model_kwargs: Passed to ``prepare_inputs_for_generation`` and refreshed after
            every step.

    Returns:
        ``BeamSearchEncoderDecoderOutput`` or ``BeamSearchDecoderOnlyOutput`` when
        ``return_dict_in_generate`` is true (chosen by ``self.config.is_encoder_decoder``),
        otherwise ``sequence_outputs["sequences"]`` from ``beam_scorer.finalize``.

    Raises:
        ValueError: If the first dimension of ``input_ids`` is not ``num_beams * batch_size``.
    """
    # init values
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
    if max_length is not None:
        warnings.warn(
            "`max_length` is deprecated in this function, use"
            " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
            UserWarning,
        )
        stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
    if len(stopping_criteria) == 0:
        warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
    pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
    output_scores = output_scores if output_scores is not None else self.config.output_scores
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict_in_generate = (
        return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
    )
    # the scorer keeps one hypotheses container per batch element
    batch_size = len(beam_scorer._beam_hyps)
    num_beams = beam_scorer.num_beams
    batch_beam_size, cur_len = input_ids.shape
    if num_beams * batch_size != batch_beam_size:
        raise ValueError(
            f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
        )
    # init attention / hidden states / scores tuples
    scores = () if (return_dict_in_generate and output_scores) else None
    beam_indices = (
        tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
    )
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
    # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
    # (these two names are only bound here and only read in the matching return branch)
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )
    # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
    # of the first beam are considered to avoid sampling the exact same tokens across all beams.
    beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
    beam_scores[:, 1:] = -1e9
    beam_scores = beam_scores.view((batch_size * num_beams,))
    # timing instrumentation: t_start is the overall start, t_dict sums seconds per phase
    # NOTE(review): `default_timer` is presumably `timeit.default_timer` -- confirm the import
    t_start = default_timer()
    t_dict = defaultdict(lambda: 0.0)
    this_peer_finished = False  # used by synced_gpus only
    while True:
        if synced_gpus:
            # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
            # The following logic allows an early break if all peers finished generating their sequence
            this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
            # send 0.0 if we finished, 1.0 otherwise
            dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
            # did all peers finish? the reduced sum will be 0.0 then
            if this_peer_finished_flag.item() == 0.0:
                break
        # start of the 'model_2' phase: input preparation + forward pass + the sync below
        t0 = default_timer()
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
        # t_dict['prepare_inputs'] += default_timer() - t0
        # t0 = default_timer()
        outputs = self(
            **model_inputs,
            return_dict=True,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        # t_dict['model'] += default_timer() - t0
        # t0 = default_timer()
        if synced_gpus and this_peer_finished:
            cur_len = cur_len + 1
            continue  # don't waste resources running the code we don't need
        next_token_logits = outputs.logits[:, -1, :]
        # Copying a reduced scalar to the CPU presumably forces pending device work to finish,
        # so that 'model_2' reflects the real forward time. The value summed under '_' is not
        # a duration, yet it is printed with the timings below -- NOTE(review): confirm intent.
        t_dict['_'] += next_token_logits.sum().cpu().numpy()
        t_dict['model_2'] += default_timer() - t0
        # start of the 'scorer' phase: log-softmax, logits processors, top-k and scorer.process
        t0 = default_timer()
        # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
        # cannot be generated both before and after the `nn.functional.log_softmax` operation.
        next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
        next_token_scores = nn.functional.log_softmax(
            next_token_logits, dim=-1
        )  # (batch_size * num_beams, vocab_size)
        # t_dict['_'] += next_token_scores.sum().cpu().numpy()
        # t_dict['calc_next_token'] += default_timer() - t0
        # t0 = default_timer()
        next_token_scores_processed = logits_processor(input_ids, next_token_scores)
        # add each beam's running score to all of its candidate continuations
        next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
        # t_dict['_'] += next_token_scores.sum().cpu().numpy()
        # t_dict['calc_scores'] += default_timer() - t0
        # t0 = default_timer()
        # Store scores, attentions and hidden_states when required
        if return_dict_in_generate:
            if output_scores:
                scores += (next_token_scores_processed,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)
            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )
        # t_dict['related_return_dict_in_generate'] += default_timer() - t0
        # t0 = default_timer()
        # reshape for beam search: one row per batch element, all beams' vocabularies side by side
        vocab_size = next_token_scores.shape[-1]
        next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
        # t_dict['_'] += next_token_scores.sum().cpu().numpy()
        # t_dict['reshape'] += default_timer() - t0
        # t0 = default_timer()
        # Sample 2 next tokens for each beam (so we have some spare tokens and match output of beam search)
        next_token_scores, next_tokens = torch.topk(
            next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
        )
        # t_dict['_'] += next_token_scores.sum().cpu().numpy()
        # t_dict['topk'] += default_timer() - t0
        # t0 = default_timer()
        # split each flat index back into (source beam, token id)
        next_indices = torch_int_div(next_tokens, vocab_size)
        next_tokens = next_tokens % vocab_size
        # t_dict['sample'] += default_timer() - t0
        # t0 = default_timer()
        # stateless
        beam_outputs = beam_scorer.process(
            input_ids,
            next_token_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            beam_indices=beam_indices,
        )
        t_dict['scorer'] += default_timer() - t0
        # start of the 'updates' phase: reorder inputs/cache and extend the sequences
        t0 = default_timer()
        beam_scores = beam_outputs["next_beam_scores"]
        beam_next_tokens = beam_outputs["next_beam_tokens"]
        beam_idx = beam_outputs["next_beam_indices"]
        # keep the rows of the surviving beams and append their chosen tokens
        input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )
        # the cached state has to follow the same beam reordering as input_ids
        if model_kwargs["past"] is not None:
            model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)
        if return_dict_in_generate and output_scores:
            beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))
        # increase cur_len
        cur_len = cur_len + 1
        t_dict['updates'] += default_timer() - t0
        if beam_scorer.is_done or stopping_criteria(input_ids, scores):
            if not synced_gpus:
                break
            else:
                this_peer_finished = True
    # next_tokens / next_indices still hold the candidates of the last completed step
    sequence_outputs = beam_scorer.finalize(
        input_ids,
        beam_scores,
        next_tokens,
        next_indices,
        pad_token_id=pad_token_id,
        eos_token_id=eos_token_id,
        max_length=stopping_criteria.max_length,
        beam_indices=beam_indices,
    )
    # one-line timing report written to stdout after generation finishes
    print(
        f'beam_search '
        f't_total={default_timer() - t_start:.3f}s '
        't_dict=' + str({k: f'{v:.3}s' for k, v in t_dict.items()})
        # ' beam_scorer=' + str({k: f'{v:.3}s' for k, v in beam_scorer.t_dict.items()})
    )
    if return_dict_in_generate:
        if not output_scores:
            sequence_outputs["sequence_scores"] = None
        if self.config.is_encoder_decoder:
            return BeamSearchEncoderDecoderOutput(
                sequences=sequence_outputs["sequences"],
                sequences_scores=sequence_outputs["sequence_scores"],
                scores=scores,
                beam_indices=sequence_outputs["beam_indices"],
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
            )
        else:
            return BeamSearchDecoderOnlyOutput(
                sequences=sequence_outputs["sequences"],
                sequences_scores=sequence_outputs["sequence_scores"],
                scores=scores,
                beam_indices=sequence_outputs["beam_indices"],
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
            )
    else:
        return sequence_outputs["sequences"]
# NOTE copied from:
# * https://github.com/huggingface/transformers/blob/main/tests/generation/test_beam_search.py
# * plus a few pieces from https://github.com/huggingface/transformers/blob/main/tests/test_modeling_common.py#L2609
import random
import unittest
import torch
from transformers.generation import BeamHypotheses
from transformers.testing_utils import require_torch, torch_device
from ....utils.huggingface.beam_search import MyBeamSearchScorer
# Fallback random source used by the tensor helpers below when no `rng` is passed.
# It is created without a seed, so the generated values differ from run to run.
global_rng = random.Random()
def ids_tensor(shape, vocab_size, rng=None, name=None):
    """Build a random integer tensor of the requested shape.

    Values are drawn one at a time with ``rng.randint(0, vocab_size - 1)``, so every
    entry lies in ``[0, vocab_size - 1]``. The result has dtype ``torch.long`` and is
    created on ``torch_device``.

    Args:
        shape: Iterable of dimension sizes.
        vocab_size: Exclusive upper bound of the generated ids.
        rng: ``random.Random``-like source; the module-level ``global_rng`` when ``None``.
        name: Unused.

    Returns:
        A contiguous tensor viewed as ``shape``.
    """
    source = global_rng if rng is None else rng
    element_count = 1
    for extent in shape:
        element_count = element_count * extent
    flat_ids = [source.randint(0, vocab_size - 1) for _ in range(element_count)]
    return torch.tensor(data=flat_ids, dtype=torch.long, device=torch_device).view(shape).contiguous()
def floats_tensor(shape, scale=1.0, rng=None, name=None):
    """Build a random ``torch.float`` tensor of the requested shape.

    Each entry is ``rng.random() * scale``, drawn one at a time, and the result is
    created on ``torch_device``.

    Args:
        shape: Iterable of dimension sizes.
        scale: Factor applied to every drawn value.
        rng: ``random.Random``-like source; the module-level ``global_rng`` when ``None``.
        name: Unused.

    Returns:
        A contiguous tensor viewed as ``shape``.
    """
    source = global_rng if rng is None else rng
    element_count = 1
    for extent in shape:
        element_count = element_count * extent
    flat_values = [source.random() * scale for _ in range(element_count)]
    return torch.tensor(data=flat_values, dtype=torch.float, device=torch_device).view(shape).contiguous()
class BeamSearchTester:
def __init__(
self,
parent,
batch_size=3,
sequence_length=10,
vocab_size=99,
pad_token_id=0,
max_length=20,
num_beams=4,
length_penalty=2.0,
do_early_stopping=True,
num_beam_hyps_to_keep=2,
):
self.parent = parent
self.batch_size = batch_size
self.sequence_length = sequence_length
self.vocab_size = vocab_size
self.pad_token_id = pad_token_id
self.max_length = max_length
self.num_beams = num_beams
self.length_penalty = length_penalty
self.do_early_stopping = do_early_stopping
self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
# cannot be randomely generated
self.eos_token_id = vocab_size + 1
def prepare_beam_scorer(self, **kwargs):
    """Create a ``MyBeamSearchScorer`` on ``torch_device``.

    Each of ``batch_size``, ``num_beams``, ``length_penalty``, ``do_early_stopping`` and
    ``num_beam_hyps_to_keep`` is taken from ``kwargs`` when present, otherwise from the
    attribute of the same name on this tester.
    """
    overridable = (
        "batch_size",
        "num_beams",
        "length_penalty",
        "do_early_stopping",
        "num_beam_hyps_to_keep",
    )
    settings = {key: kwargs.get(key, getattr(self, key)) for key in overridable}
    return MyBeamSearchScorer(device=torch_device, **settings)
def prepare_inputs(self):
    """Generate random inputs for the scorer checks.

    Returns:
        A tuple ``(input_ids, next_tokens, next_indices, next_scores)``:
        ``input_ids`` has shape ``(batch_size * num_beams, sequence_length)``; the other
        three have shape ``(batch_size, 2 * num_beams)``. ``next_indices`` holds values
        below ``num_beams`` and ``next_scores`` holds negated random floats sorted in
        descending order along the last dimension.
    """
    candidate_shape = (self.batch_size, 2 * self.num_beams)
    prompt_shape = (self.batch_size * self.num_beams, self.sequence_length)

    input_ids = ids_tensor(prompt_shape, self.vocab_size)
    next_tokens = ids_tensor(candidate_shape, self.vocab_size).to(torch_device)
    next_indices = ids_tensor(candidate_shape, self.num_beams).to(torch_device)

    negated = -floats_tensor(candidate_shape).to(torch_device)
    next_scores, _ = negated.sort(descending=True)
    return (input_ids, next_tokens, next_indices, next_scores)
def check_beam_hypotheses(self, input_ids, *args):
    """Check the per-batch ``BeamHypotheses`` objects held by the scorer.

    First, with early stopping enabled, verifies the number, type and ``num_beams`` of
    the hypotheses containers and that a full container reports ``is_done``. Then, with
    early stopping disabled, adds ``num_beams + 1`` hypotheses with increasing scores and
    checks ``worst_score`` and the ``is_done`` decision for a better and a worse score.

    Args:
        input_ids: Rows are used as the hypotheses added to the container.
        *args: Ignored.
    """
    check = self.parent

    # --- early stopping enabled ---
    scorer = self.prepare_beam_scorer(do_early_stopping=True)
    hypotheses = scorer._beam_hyps[0]
    # one container per batch element
    check.assertEqual(len(scorer._beam_hyps), self.batch_size)
    # container has the expected type
    check.assertTrue(isinstance(hypotheses, BeamHypotheses))
    # container was built with the tester's beam count
    check.assertEqual(hypotheses.num_beams, self.num_beams)
    # fill the container with `num_beams` equally scored hypotheses
    for row in range(self.num_beams):
        hypotheses.add(input_ids[row], -10.0)
    # with early stopping the score passed in is expected not to matter
    check.assertTrue(hypotheses.is_done(-10.0, 5))

    # --- early stopping disabled ---
    scorer = self.prepare_beam_scorer(do_early_stopping=False)
    hypotheses = scorer._beam_hyps[0]
    # one hypothesis more than fits, with scores -10.0, -9.0, ...
    for offset in range(self.num_beams + 1):
        hypotheses.add(input_ids[offset], -10.0 + float(offset))
    # the -10.0 entry is expected to be dropped, leaving -9.0 (length-normalised) as worst
    expected_worst = -9.0 / (self.sequence_length ** hypotheses.length_penalty)
    check.assertAlmostEqual(hypotheses.worst_score, expected_worst)
    # a candidate scoring -5.0 beats the worst entry, so the search is not finished
    check.assertFalse(hypotheses.is_done(-5.0, self.sequence_length))
    # a candidate scoring -20.0 cannot improve the container, so the search is finished
    check.assertTrue(hypotheses.is_done(-20.0, self.sequence_length))
def check_beam_scorer_update(self, input_ids, next_tokens, next_indices, next_scores):
    """Exercise ``process`` of the beam scorer in three scenarios.

    1. Every candidate of batch 0 is the EOS token: ``process`` must raise ``ValueError``.
    2. The first ``num_beams`` candidates of every batch are EOS: the scorer must report
       ``is_done`` afterwards.
    3. Only candidate 1 of every batch is EOS: the returned beams must be candidate 0 and
       candidates ``2..num_beams``, and the finished hypothesis (tokens and beam indices)
       must be stored in the scorer's hypotheses container.

    Args:
        input_ids: Current sequences, one row per (batch, beam).
        next_tokens: Candidate token ids per batch row.
        next_indices: Source beam of every candidate, relative to its batch.
        next_scores: Candidate scores per batch row.
    """
    # check too many eos tokens
    beam_scorer = self.prepare_beam_scorer()
    tokens = next_tokens.clone()
    tokens[0, :] = self.eos_token_id
    with self.parent.assertRaises(ValueError):
        beam_scorer.process(input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id)

    # check all batches are done
    beam_scorer = self.prepare_beam_scorer()
    tokens = next_tokens.clone()
    tokens[:, : self.num_beams] = self.eos_token_id
    # every row of beam_indices is (0, 1, ..., seq_len - 1); reused by the third scenario
    beam_indices = torch.zeros_like(input_ids) + torch.arange(input_ids.shape[-1], device=input_ids.device)
    beam_indices = tuple(tuple(b) for b in beam_indices)
    beam_scorer.process(
        input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id, beam_indices=beam_indices
    )
    # beam scorer should be done
    self.parent.assertTrue(beam_scorer.is_done)

    # check that a single eos candidate per batch is removed from the returned beams
    beam_scorer = self.prepare_beam_scorer()
    tokens = next_tokens.clone()
    tokens[:, 1] = self.eos_token_id
    beam_outputs = beam_scorer.process(
        input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id, beam_indices=beam_indices
    )
    output_scores = beam_outputs["next_beam_scores"]
    output_tokens = beam_outputs["next_beam_tokens"]
    output_indices = beam_outputs["next_beam_indices"]

    def cut_expected_tensor(tensor):
        # drop column 1 (the eos candidate) and keep the next `num_beams` columns
        return torch.cat([tensor[:, :1], tensor[:, 2: self.num_beams + 1]], dim=1).flatten()

    # check all outputs
    # cut out id of eos token and take best `num_beams` outputs
    expected_output_tokens = cut_expected_tensor(tokens)
    expected_output_scores = cut_expected_tensor(next_scores)
    # add num_beams * batch_idx to turn per-batch beam indices into flat row indices
    expected_output_indices = (
        cut_expected_tensor(next_indices)
        + (torch.arange(self.num_beams * self.batch_size,
                        device=torch_device) // self.num_beams) * self.num_beams
    )
    self.parent.assertListEqual(expected_output_tokens.tolist(), output_tokens.tolist())
    self.parent.assertListEqual(expected_output_indices.tolist(), output_indices.tolist())
    self.parent.assertTrue(torch.allclose(expected_output_scores, output_scores, atol=1e-3))

    # make sure ids of eos token are correctly saved in beam_hyps of beam scorer
    # The expected prefix mirrors how `beam_indices` was built above (one entry per input
    # position) instead of hard-coding the default sequence length of 10.
    expected_beam_indices = list(range(input_ids.shape[-1]))
    for batch_idx in range(self.batch_size):
        correct_idx = batch_idx * self.num_beams + next_indices[batch_idx, 1]
        self.parent.assertListEqual(
            input_ids[correct_idx].tolist(), beam_scorer._beam_hyps[batch_idx].beams[0][1].tolist()
        )
        self.parent.assertListEqual(
            expected_beam_indices + [correct_idx],
            torch.tensor(beam_scorer._beam_hyps[batch_idx].beams[0][2]).tolist(),
        )
def check_beam_scores_finalize(self, input_ids, next_tokens, next_indices, next_scores):
    """Check the sequences and scores returned by the beam scorer's ``finalize``.

    A single ``process`` step is run in which the first candidate of the first
    batch entry is forced to be the eos token with score 0.0. ``finalize`` is
    then called twice: once keeping one hypothesis per batch entry and once
    keeping ``self.num_beams`` hypotheses per batch entry.

    Args:
        input_ids: current beam inputs; re-indexed with the beam indices
            returned by ``process`` and extended by the selected tokens.
        next_tokens: candidate next tokens; cloned before being edited.
        next_indices: candidate beam indices handed to ``process``.
        next_scores: candidate scores; cloned before being edited so the
            caller's tensor is not modified.
    """
    # max_length should be only one more than current input_ids to check that eos is correctly appended
    max_length = self.sequence_length + 1
    beam_scorer = self.prepare_beam_scorer(num_beam_hyps_to_keep=1, length_penalty=1.0, do_early_stopping=False)

    # update beams and append to input_ids; edit copies so the caller's tensors stay untouched
    tokens = next_tokens.clone()
    next_scores = next_scores.clone()
    # first batch, first output has to finish with eos token id since scores are correctly sorted
    tokens[0, 0] = self.eos_token_id
    # make sure corresponding score is as good as possible to surely be picked first
    next_scores[0, 0] = 0.0

    beam_outputs = beam_scorer.process(
        input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id
    )
    output_scores = beam_outputs["next_beam_scores"]
    output_tokens = beam_outputs["next_beam_tokens"]
    output_indices = beam_outputs["next_beam_indices"]
    input_ids = torch.cat([input_ids[output_indices, :], output_tokens.unsqueeze(-1)], dim=-1)

    # every row of beam_indices is 0, 1, ..., input_ids.shape[-1] - 1
    beam_indices = torch.zeros_like(input_ids) + torch.arange(input_ids.shape[-1], device=input_ids.device)
    beam_indices = tuple(tuple(b) for b in beam_indices)

    def run_finalize():
        # both finalize calls below use exactly the same arguments; only the
        # scorer's `num_beam_hyps_to_keep` differs between them
        return beam_scorer.finalize(
            input_ids,
            output_scores,
            output_tokens,
            output_indices,
            pad_token_id=self.pad_token_id,
            eos_token_id=self.eos_token_id,
            max_length=max_length,
            beam_indices=beam_indices,
        )

    # finalize
    sequence_output = run_finalize()
    sequences = sequence_output["sequences"]
    sequence_scores = sequence_output["sequence_scores"]

    # since `num_beam_hyps_to_keep` = 1 => only return `batch_size` x `max_length`
    self.parent.assertListEqual(list(sequences.shape), [self.batch_size, max_length])
    self.parent.assertListEqual(list(sequence_scores.shape), [self.batch_size])

    # check sequence_scores: none of them may be positive
    self.parent.assertFalse((sequence_scores > 0).any().item())

    # first batch has to finish with eos_token
    self.parent.assertEqual(sequences[0, -1].item(), self.eos_token_id)
    # no other batch entry may finish with the eos token (checked for every
    # remaining row rather than only for rows 1 and 2)
    for batch_idx in range(1, self.batch_size):
        self.parent.assertNotEqual(sequences[batch_idx, -1].item(), self.eos_token_id)

    # now test that if `num_beam_hyps_to_keep` equals `num_beams` => all beams are returned
    beam_scorer.num_beam_hyps_to_keep = self.num_beams
    sequence_output = run_finalize()
    sequences = sequence_output["sequences"]
    sequence_scores = sequence_output["sequence_scores"]

    self.parent.assertListEqual(list(sequences.shape), [self.num_beams * self.batch_size, max_length])
    self.parent.assertListEqual(list(sequence_scores.shape), [self.num_beams * self.batch_size])
@require_torch
class BeamSearchTest(unittest.TestCase):
    """Test case that runs the ``BeamSearchTester`` checks on freshly prepared inputs."""

    def setUp(self):
        # unittest calls setUp before each test method, so every test gets its own tester
        self.beam_search_tester = BeamSearchTester(self)

    def test_beam_hypotheses(self):
        # prepare inputs first, then hand them to the hypotheses check
        tester = self.beam_search_tester
        prepared = tester.prepare_inputs()
        tester.check_beam_hypotheses(*prepared)

    def test_beam_scorer_update(self):
        # prepare inputs first, then hand them to the scorer-update check
        tester = self.beam_search_tester
        prepared = tester.prepare_inputs()
        tester.check_beam_scorer_update(*prepared)

    def test_beam_scorer_finalize(self):
        # prepare inputs first, then hand them to the finalize check
        tester = self.beam_search_tester
        prepared = tester.prepare_inputs()
        tester.check_beam_scores_finalize(*prepared)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment