Using padding and prefill during inference in huggingface transformers
import re
import sys
import time
import tqdm
import torch
from datasets import load_dataset, concatenate_datasets
from transformers import AutoTokenizer, LlamaForCausalLM
# torch==2.0.1
# transformers==4.34.0
# tokenizers==0.14.0
# python: 3.10
# CUDA: 11.8
def _run_configuration(model, tokenizer, instances, tokens_per_batch, max_new_tokens,
                       use_padding, use_prefill):
    """Decode all instances under one (padding, prefill) setting and print the results.

    The timer starts before batching and before the prefill cache is built,
    so both are included in the reported duration.
    """
    start_time = time.time()
    # Padded batches may mix lengths; unpadded batches only group equal lengths.
    make_batches = get_batches if use_padding else get_batches_equal_length
    batches = make_batches(instances, tokens_per_batch)
    prefill = get_prefill_input_ids(instances, model) if use_prefill else None
    for batch in tqdm.tqdm(batches):
        inference(model, tokenizer, batch, prefill, max_new_tokens)
    print('duration in seconds:', int(time.time() - start_time))
    accuracy = evaluate(instances)
    print('accuracy:', accuracy)


def main():
    """Benchmark batched greedy decoding with and without padding and prefill.

    For each attention setting (flash attention 2 off, then on) the model and
    the instances are loaded once, and the same instances are decoded under
    four configurations: {equal-length batches, padded batches} x
    {no prefill, shared-prefix prefill}. Duration and accuracy are printed
    for every configuration.
    """
    hf_auth_token = ''
    tokens_per_batch = 20000
    max_new_tokens = 5
    # (padding, prefill) pairs, in the order the results are reported.
    configurations = [(False, False), (False, True), (True, False), (True, True)]
    for use_flash_attention_2 in [False, True]:
        model, tokenizer = get_model(hf_auth_token, use_flash_attention_2=use_flash_attention_2)
        instances = get_instances(tokenizer, tokens_per_batch, max_new_tokens)
        print('\nnumber of instances:', len(instances))
        for use_padding, use_prefill in configurations:
            print(f'\npadding {use_padding}, prefill {use_prefill}, flash_attention_2 {use_flash_attention_2}')
            _run_configuration(model, tokenizer, instances, tokens_per_batch, max_new_tokens,
                               use_padding, use_prefill)
        # Drop the current model before the next one is loaded so two copies
        # of the weights are never alive at the same time.
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
def get_padded_inputs(list_of_input_ids, tokenizer, model, prefill=None):
    """Build left-padded model inputs for a batch, optionally on top of a prefill cache.

    Args:
        list_of_input_ids: One list of token ids per batch row.
        tokenizer: Object exposing ``pad_token_id``.
        model: Object exposing ``device``; all tensors are moved there.
        prefill: ``None`` or a ``(prefill_input_ids, prefill_key_values)``
            pair. When given, the first ``len(prefill_input_ids)`` tokens of
            every row are dropped because the cache already covers them.

    Returns:
        A dict with ``input_ids``, ``position_ids``, ``attention_mask`` and
        ``past_key_values`` (``None`` without prefill).
    """
    num_prefill_tokens = 0
    prefill_key_values = None
    if prefill is not None:
        prefill_input_ids, prefill_key_values = prefill
        num_prefill_tokens = len(prefill_input_ids)
        list_of_input_ids = [input_ids[num_prefill_tokens:] for input_ids in list_of_input_ids]

    max_length = max(len(input_ids) for input_ids in list_of_input_ids)
    padded_list_of_input_ids = []
    attention_mask = []
    position_ids = []
    for input_ids in list_of_input_ids:
        num_real_tokens = len(input_ids)
        num_pad_tokens = max_length - num_real_tokens
        padded_input_ids = [tokenizer.pad_token_id] * num_pad_tokens + list(input_ids)
        # Position ids cover only the non-prefill tokens and are offset by the
        # number of prefill tokens (offset 0 without prefill). Padding slots
        # get a placeholder position id of 1; the attention mask below is 0
        # for them.
        padded_position_ids = [1] * num_pad_tokens + list(
            range(num_prefill_tokens, num_prefill_tokens + num_real_tokens))
        # The mask is active for the prefill tokens, inactive for the padding
        # tokens, and active again for the new input tokens.
        padded_attention_mask = [1] * num_prefill_tokens + [0] * num_pad_tokens + [1] * num_real_tokens
        assert len(padded_position_ids) == len(padded_input_ids)
        padded_list_of_input_ids.append(padded_input_ids)
        position_ids.append(padded_position_ids)
        attention_mask.append(padded_attention_mask)

    inputs = {
        'input_ids': torch.as_tensor(padded_list_of_input_ids).to(model.device),
        'position_ids': torch.as_tensor(position_ids).to(model.device),
        'attention_mask': torch.as_tensor(attention_mask).to(model.device),
        'past_key_values': None,
    }
    if prefill is not None:
        # Adapt the single-sequence cache to the current batch size.
        inputs['past_key_values'] = get_tiled_cache(prefill_key_values, batch_size=len(list_of_input_ids))
    return inputs
def get_tiled_cache(prefill_key_values, batch_size):
    """Repeat a single-sequence key/value cache along the batch dimension.

    Args:
        prefill_key_values: Iterable of per-layer entries whose first two
            items are the key and value tensors. Each tensor must have size 1
            in dimension 0 and is tiled with the pattern
            ``(batch_size, 1, 1, 1)``, i.e. it is treated as 4-D.
            # presumably (batch, heads, seq, head_dim) -- confirm against the model
        batch_size: Number of copies to create along dimension 0.

    Returns:
        A tuple with one ``(key_cache, value_cache)`` pair per layer.

    Raises:
        ValueError: If a cache tensor does not have batch size 1.
    """
    tiled_layers = []
    for layer_index, layer_cache in enumerate(prefill_key_values):
        key_cache, value_cache = layer_cache[0], layer_cache[1]
        # Explicit check instead of `assert`, so it still runs under `python -O`.
        if key_cache.shape[0] != 1 or value_cache.shape[0] != 1:
            raise ValueError(
                f'prefill cache of layer {layer_index} must have batch size 1, '
                f'got {key_cache.shape[0]} (keys) and {value_cache.shape[0]} (values)')
        reps = (batch_size, 1, 1, 1)
        tiled_layers.append((torch.tile(key_cache, reps), torch.tile(value_cache, reps)))
    return tuple(tiled_layers)
def update_generation_inputs(inputs, next_token_ids, past_key_values=None):
    """Advance the model inputs by one decoding step (in place).

    Args:
        inputs: Dict with ``input_ids``, ``attention_mask`` and
            ``position_ids`` tensors from the previous step. It is mutated
            and also returned.
        next_token_ids: Newly chosen token per row, either of shape
            ``(batch, 1)`` or ``(batch,)``.
        past_key_values: Cache returned by the previous forward pass.

    Returns:
        The same ``inputs`` dict, now describing only the new token per row.
    """
    if next_token_ids.dim() == 1:
        # Accept a flat vector of ids and turn it into one column per row.
        next_token_ids = next_token_ids.unsqueeze(-1)
    inputs['input_ids'] = next_token_ids
    # The new token is always attended to: extend the mask by a column of
    # ones with the mask's own dtype and device.
    new_mask_column = inputs['attention_mask'].new_ones(next_token_ids.shape)
    inputs['attention_mask'] = torch.cat((inputs['attention_mask'], new_mask_column), dim=-1)
    # The next position is the last position of each row plus one.
    inputs['position_ids'] = (inputs['position_ids'][:, -1] + 1).unsqueeze(1)
    inputs['past_key_values'] = past_key_values
    return inputs
def _first_present(record, *keys):
    """Return the first value in ``record`` under ``keys`` that is not None, else None."""
    for key in keys:
        value = record.get(key)
        if value is not None:
            return value
    return None


def get_instances(tokenizer, tokens_per_batch, max_new_tokens=5):
    """Load the math word problems and turn them into tokenized prompts.

    Only instances whose answer contains exactly one number are kept. Each
    returned instance carries ``text`` (the full prompt), ``label`` (the
    answer rounded to two decimals), ``input_ids`` and ``num_tokens``
    (prompt tokens plus ``max_new_tokens``).

    Args:
        tokenizer: Object exposing ``encode(text, add_special_tokens=False)``.
        tokens_per_batch: Upper bound that a single instance must fit into.
        max_new_tokens: Number of tokens reserved for generation.

    Returns:
        List of instance dicts.

    Raises:
        ValueError: If a single instance exceeds ``tokens_per_batch``.
    """
    # The system block is closed with <</SYS>> as in the Llama 2 chat format.
    prompt = """<s>[INST] <<SYS>>
Always assist with care, respect, and truth. Respond with utmost utility yet securely. Avoid harmful, unethical, prejudiced, or negative content. Ensure replies promote fairness and positivity.
<</SYS>>

{user_msg_1} [/INST] {model_answer_1_start}"""
    instances = []
    dataset1 = load_dataset("EleutherAI/asdiv")
    dataset2 = load_dataset("ChilleD/SVAMP")
    number_pattern = re.compile(r"-?(?<!\d)\d{1,10}(?:,\d{3})*(?:\.\d+)?%?(?!\d)")
    for instance in concatenate_datasets([dataset1['validation'], dataset2['train'], dataset2['test']]):
        # The two datasets use differently capitalized column names.
        # NOTE(review): if concatenation fills the other dataset's columns
        # with None, a plain `.get(key, fallback)` never reaches the
        # fallback -- hence the explicit "first value that is not None".
        answer = str(_first_present(instance, 'answer', 'Answer'))
        numbers = number_pattern.findall(answer)
        if len(numbers) != 1:
            continue
        body = _first_present(instance, 'body', 'Body')
        question = _first_present(instance, 'question', 'Question')
        if body is None or question is None:
            continue
        question = body + ' ' + question + '\nPlease respond only with the result.'
        instance['text'] = prompt.format(user_msg_1=question, model_answer_1_start='The answer is: ')
        # The pattern may capture a trailing '%', which float() cannot parse.
        instance['label'] = round(float(numbers[0].replace(",", "").rstrip("%")), 2)
        instance['input_ids'] = tokenizer.encode(instance['text'], add_special_tokens=False)
        instance['num_tokens'] = len(instance['input_ids']) + max_new_tokens
        if instance['num_tokens'] > tokens_per_batch:
            raise ValueError(
                f"instance needs {instance['num_tokens']} tokens but tokens_per_batch is {tokens_per_batch}")
        instances.append(instance)
    return instances
def get_model(hf_auth_token, model_name='meta-llama/Llama-2-7b-chat-hf', use_flash_attention_2=False):
    """Load the tokenizer and the causal LM used for the benchmark.

    Args:
        hf_auth_token: Hugging Face access token. It is overridden by the
            first command line argument when one is given.
        model_name: Model repository to load.
        use_flash_attention_2: Forwarded to ``from_pretrained``.

    Returns:
        ``(model, tokenizer)``; the tokenizer pads on the left, and the pad,
        bos and eos ids are fixed to 0, 1 and 2 on both objects.

    Raises:
        ValueError: If no auth token is available.
    """
    if len(sys.argv) > 1:
        hf_auth_token = sys.argv[1]
    if hf_auth_token == '':
        raise ValueError('Please provide your huggingface auth key in the script or as first command line argument')
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_auth_token)
    # NOTE(review): the original argument list was cut off after `model_name,`.
    # The dtype and device handling below is a reconstruction -- confirm
    # against the intended setup. Half precision is chosen on GPU because
    # flash attention 2 does not run in float32.
    use_cuda = torch.cuda.is_available()
    model = LlamaForCausalLM.from_pretrained(
        model_name,
        token=hf_auth_token,
        torch_dtype=torch.float16 if use_cuda else torch.float32,
        use_flash_attention_2=use_flash_attention_2,
    )
    model.to('cuda' if use_cuda else 'cpu')
    model.eval()
    tokenizer.padding_side = 'left'
    model.config.pad_token_id = tokenizer.pad_token_id = 0
    model.config.bos_token_id = tokenizer.bos_token_id = 1
    model.config.eos_token_id = tokenizer.eos_token_id = 2
    return model, tokenizer
def get_prefill_input_ids(instances, model):
    """Find the token prefix shared by all instances and pre-compute its cache.

    Args:
        instances: At least two dicts with an ``input_ids`` list each.
        model: Callable as ``model(input_ids=..., use_cache=True)`` whose
            result exposes ``past_key_values``; also exposes ``device``.

    Returns:
        ``(prefill_input_ids, past_key_values)`` where ``prefill_input_ids``
        is the shared prefix and ``past_key_values`` is the cache the model
        produced for it (batch size 1).

    Raises:
        ValueError: If fewer than two instances are given or they share no
            usable prefix.
    """
    if len(instances) < 2:
        raise ValueError('at least two instances are needed to determine a shared prefix')
    token_lists = [ins['input_ids'] for ins in instances]

    # Walk the sequences column by column and stop at the first position
    # where they disagree (zip also stops at the shortest sequence).
    num_prefill_tokens = 0
    for column in zip(*token_lists):
        if any(token != column[0] for token in column):
            break
        num_prefill_tokens += 1
    # Keep at least one non-prefill token per instance so every row still
    # has something to feed to the model.
    shortest = min(len(tokens) for tokens in token_lists)
    num_prefill_tokens = min(num_prefill_tokens, shortest - 1)
    if num_prefill_tokens < 1:
        raise ValueError('instances do not share a common token prefix')

    prefill_input_ids = list(token_lists[0][:num_prefill_tokens])
    with torch.no_grad():
        input_ids = torch.tensor([prefill_input_ids]).to(model.device)
        outputs = model(input_ids=input_ids, use_cache=True)
    past_key_values = outputs.past_key_values
    return prefill_input_ids, past_key_values
def get_batches_equal_length(instances, tokens_per_batch):
    """Group instances into batches that need no padding.

    Instances are sorted by ``num_tokens`` in descending order. An instance
    joins the current batch only if it has the same ``num_tokens`` as the
    batch and ``num_tokens * batch_size`` stays within ``tokens_per_batch``.

    Args:
        instances: Dicts with a ``num_tokens`` entry.
        tokens_per_batch: Token budget per batch.

    Returns:
        List of batches (lists of instances); empty for empty input.
    """
    if not instances:
        return []
    ordered = sorted(instances, key=lambda ins: -ins['num_tokens'])
    batches = []
    batch = [ordered[0]]
    num_tokens = ordered[0]['num_tokens']
    for instance in ordered[1:]:
        same_length = instance['num_tokens'] == num_tokens
        if same_length and num_tokens * (len(batch) + 1) <= tokens_per_batch:
            batch.append(instance)
        else:
            batches.append(batch)
            batch = [instance]
            num_tokens = instance['num_tokens']
    # Flush the batch that was still being filled when the loop ended.
    batches.append(batch)
    return batches
def get_batches(instances, tokens_per_batch):
    """Group instances into batches that may require padding.

    Instances are sorted by ``num_tokens`` in descending order, so the first
    instance of a batch is its longest. An instance joins the current batch
    while ``longest * batch_size`` (the padded size) stays within
    ``tokens_per_batch``.

    Args:
        instances: Dicts with a ``num_tokens`` entry.
        tokens_per_batch: Token budget per batch, padding included.

    Returns:
        List of batches (lists of instances); empty for empty input.
    """
    if not instances:
        return []
    ordered = sorted(instances, key=lambda ins: -ins['num_tokens'])
    batches = []
    batch = [ordered[0]]
    num_tokens = ordered[0]['num_tokens']
    for instance in ordered[1:]:
        if num_tokens * (len(batch) + 1) <= tokens_per_batch:
            batch.append(instance)
        else:
            batches.append(batch)
            batch = [instance]
            num_tokens = instance['num_tokens']
    # Flush the batch that was still being filled when the loop ended.
    batches.append(batch)
    return batches
def inference(model, tokenizer, batch, prefill=None, max_new_tokens=5):
    """Greedily decode up to ``max_new_tokens`` tokens for every instance in ``batch``.

    The decoded text is stored in ``instance['response']``; nothing is
    returned. A row stops collecting tokens once it has produced the
    tokenizer's eos token, and decoding ends early when all rows have.

    Args:
        model: Causal LM called as ``model(**inputs, use_cache=True)``.
        tokenizer: Provides ``pad_token_id``, ``eos_token_id`` and ``decode``.
        batch: List of instance dicts with ``input_ids``.
        prefill: Optional ``(prefill_input_ids, prefill_key_values)`` pair
            passed on to ``get_padded_inputs``.
        max_new_tokens: Maximum number of decoding steps.
    """
    input_ids = [instance['input_ids'] for instance in batch]
    inputs = get_padded_inputs(input_ids, tokenizer, model, prefill)
    generated_sequences = [[] for _ in batch]
    is_finished = [False] * len(batch)
    for _ in range(max_new_tokens):
        with torch.no_grad():
            outputs = model(**inputs, use_cache=True)
        # Greedy choice from the logits at the last position of every row.
        next_token_ids = torch.argmax(outputs.logits[:, -1, :], dim=-1)
        # One transfer for the whole batch instead of one `.item()` per row.
        for output_id, next_token_id in enumerate(next_token_ids.tolist()):
            if is_finished[output_id]:
                continue
            generated_sequences[output_id].append(next_token_id)
            if next_token_id == tokenizer.eos_token_id:
                is_finished[output_id] = True
        if all(is_finished):
            break
        inputs = update_generation_inputs(inputs=inputs,
                                          next_token_ids=next_token_ids.unsqueeze(-1),
                                          past_key_values=outputs.past_key_values)
    for instance, generated_sequence in zip(batch, generated_sequences):
        output_text = tokenizer.decode(generated_sequence, skip_special_tokens=True)
        instance['response'] = output_text
def evaluate(instances):
    """Return the percentage of instances whose first predicted number equals the label.

    The first number found in ``instance['response']`` is rounded to two
    decimals and compared with ``instance['label']``. Instances without a
    response or without a number in it count as wrong.

    Args:
        instances: Dicts with ``label`` and (normally) ``response``.

    Returns:
        Accuracy in percent, rounded to one decimal; ``0.0`` for empty input.
    """
    number_pattern = re.compile(r"-?(?<!\d)\d{1,10}(?:,\d{3})*(?:\.\d+)?%?(?!\d)")
    if not instances:
        return 0.0
    correct = 0
    for instance in instances:
        response = instance.get('response') or ''
        matches = number_pattern.findall(response)
        if not matches:
            continue
        # The pattern may capture a trailing '%', which float() cannot parse.
        pred = round(float(matches[0].replace(",", "").rstrip("%")), 2)
        if pred == instance['label']:
            correct += 1
    accuracy = correct / len(instances) * 100
    return round(accuracy, 1)
# Run the benchmark only when the file is executed as a script.
if __name__ == '__main__':
    main()
With the right attention mask and position_ids you can use padding and prefill tokens in huggingface transformers. This speeds up batched inference, especially if each instance has the same system prompt prepended.

Run with

python <script_name>.py your_huggingface_auth_key

number of instances: 2089

padding False, prefill False, flash_attention_2 False
duration in seconds: 56
accuracy: 42.5

padding False, prefill True, flash_attention_2 False
duration in seconds: 35
accuracy: 42.5

padding True, prefill False, flash_attention_2 False
duration in seconds: 49
accuracy: 42.5

padding True, prefill True, flash_attention_2 False
duration in seconds: 27
accuracy: 42.5

This also works with flash attention 2, although I don't see any additional speedup.

padding False, prefill False, flash_attention_2 True
duration in seconds: 57
accuracy: 42.5

padding False, prefill True, flash_attention_2 True
duration in seconds: 35
accuracy: 42.5

padding True, prefill False, flash_attention_2 True
duration in seconds: 48
accuracy: 42.5

padding True, prefill True, flash_attention_2 True
duration in seconds: 27
accuracy: 42.5

