@raphael-sch
Last active March 21, 2024 14:49
Using padding and prefill during inference in huggingface transformers
import re
import sys
import time
import tqdm
import torch
from datasets import load_dataset, concatenate_datasets
from transformers import AutoTokenizer, LlamaForCausalLM
# torch==2.0.1
# transformers==4.34.0
# tokenizers==0.14.0
# python: 3.10
# GPU: NVIDIA RTX A6000
# CUDA: 11.8
def main():
    hf_auth_token = ''
    tokens_per_batch = 20000
    max_new_tokens = 5

    for use_flash_attention_2 in [False, True]:
        model, tokenizer = get_model(hf_auth_token, use_flash_attention_2=use_flash_attention_2)
        instances = get_instances(tokenizer, tokens_per_batch, max_new_tokens)
        print('\nnumber of instances:', len(instances))

        print(f'\npadding False, prefill False, flash_attention_2 {use_flash_attention_2}')
        start_time = time.time()
        batches = get_batches_equal_length(instances, tokens_per_batch)
        prefill = None
        for batch in tqdm.tqdm(batches):
            inference(model, tokenizer, batch, prefill, max_new_tokens)
        print('duration in seconds:', int(time.time() - start_time))
        accuracy = evaluate(instances)
        print('accuracy:', accuracy)

        print(f'\npadding False, prefill True, flash_attention_2 {use_flash_attention_2}')
        start_time = time.time()
        batches = get_batches_equal_length(instances, tokens_per_batch)
        prefill = get_prefill_input_ids(instances, model)
        for batch in tqdm.tqdm(batches):
            inference(model, tokenizer, batch, prefill, max_new_tokens)
        print('duration in seconds:', int(time.time() - start_time))
        accuracy = evaluate(instances)
        print('accuracy:', accuracy)

        print(f'\npadding True, prefill False, flash_attention_2 {use_flash_attention_2}')
        start_time = time.time()
        batches = get_batches(instances, tokens_per_batch)
        prefill = None
        for batch in tqdm.tqdm(batches):
            inference(model, tokenizer, batch, prefill, max_new_tokens)
        print('duration in seconds:', int(time.time() - start_time))
        accuracy = evaluate(instances)
        print('accuracy:', accuracy)

        print(f'\npadding True, prefill True, flash_attention_2 {use_flash_attention_2}')
        start_time = time.time()
        batches = get_batches(instances, tokens_per_batch)
        prefill = get_prefill_input_ids(instances, model)
        for batch in tqdm.tqdm(batches):
            inference(model, tokenizer, batch, prefill, max_new_tokens)
        print('duration in seconds:', int(time.time() - start_time))
        accuracy = evaluate(instances)
        print('accuracy:', accuracy)
def get_padded_inputs(list_of_input_ids, tokenizer, model, prefill=None):
    if prefill is not None:
        prefill_input_ids, prefill_key_values = prefill
        num_prefill_tokens = len(prefill_input_ids)
        list_of_input_ids = [input_ids[num_prefill_tokens:] for input_ids in list_of_input_ids]

    max_length = max(len(input_ids) for input_ids in list_of_input_ids)

    padded_list_of_input_ids = list()
    attention_mask = list()
    position_ids = list()
    for input_ids in list_of_input_ids:
        num_pad_tokens = max_length - len(input_ids)
        padded_input_ids = [tokenizer.pad_token_id] * num_pad_tokens + input_ids
        if prefill is not None:
            # position ids are only needed for "non-prefill" input ids and are offset by the number of prefill tokens
            _position_ids = [i + len(prefill_input_ids) for i in range(len(input_ids))]
            padded_position_ids = [1] * num_pad_tokens + _position_ids
            # attention mask is active for prefill tokens, not active for padding tokens, and again active for new input tokens
            padded_attention_mask = [1] * len(prefill_input_ids) + [0] * num_pad_tokens + [1] * len(input_ids)
        else:
            padded_position_ids = [1] * num_pad_tokens + list(range(len(input_ids)))
            padded_attention_mask = [0] * num_pad_tokens + [1] * len(input_ids)
        assert len(padded_position_ids) == len(padded_input_ids)
        padded_list_of_input_ids.append(padded_input_ids)
        position_ids.append(padded_position_ids)
        attention_mask.append(padded_attention_mask)

    inputs = dict()
    inputs['input_ids'] = torch.as_tensor(padded_list_of_input_ids).to(model.device)
    inputs['position_ids'] = torch.as_tensor(position_ids).to(model.device)
    inputs['attention_mask'] = torch.as_tensor(attention_mask).to(model.device)
    inputs['past_key_values'] = None

    if prefill is not None:
        prefill_input_ids, prefill_key_values = prefill
        # adapt cache to current batch size
        past_key_values = get_tiled_cache(prefill_key_values, batch_size=len(list_of_input_ids))
        inputs['past_key_values'] = past_key_values

    return inputs
def get_tiled_cache(prefill_key_values, batch_size):
    # repeat the single-sequence prefix cache along the batch dimension so it matches the current batch
    past_key_values = list()
    for layer_cache in prefill_key_values:
        key_cache = layer_cache[0]
        value_cache = layer_cache[1]
        assert key_cache.shape[0] == 1
        assert value_cache.shape[0] == 1
        key_cache = torch.tile(key_cache, (batch_size, 1, 1, 1))
        value_cache = torch.tile(value_cache, (batch_size, 1, 1, 1))
        past_key_values.append((key_cache, value_cache))
    past_key_values = tuple(past_key_values)
    return past_key_values
def update_generation_inputs(inputs, next_token_ids, past_key_values=None):
    # after the first forward pass only the newly generated token is fed in;
    # the attention mask grows by one and the position advances by one
    inputs['input_ids'] = next_token_ids
    attention_mask_cat = torch.ones_like(inputs['input_ids'])
    inputs['attention_mask'] = torch.cat((inputs['attention_mask'], attention_mask_cat), dim=-1)
    inputs['position_ids'] = (inputs['position_ids'][:, -1] + 1).unsqueeze(1)
    inputs['past_key_values'] = past_key_values
    return inputs
def get_instances(tokenizer, tokens_per_batch, max_new_tokens=5):
    prompt = """<s>[INST] <<SYS>>
Always assist with care, respect, and truth. Respond with utmost utility yet securely. Avoid harmful, unethical, prejudiced, or negative content. Ensure replies promote fairness and positivity.
<</SYS>>
{user_msg_1} [/INST] {model_answer_1_start}"""

    instances = list()
    dataset1 = load_dataset("EleutherAI/asdiv")
    dataset2 = load_dataset("ChilleD/SVAMP")

    number_pattern = re.compile(r"-?(?<!\d)\d{1,10}(?:,\d{3})*(?:\.\d+)?%?(?!\d)")
    for instance in concatenate_datasets([dataset1['validation'], dataset2['train'], dataset2['test']]):
        answer = str(instance.get('answer', instance['Answer']))
        numbers = number_pattern.findall(answer)
        if len(numbers) == 1:
            body = instance.get('body', instance['Body'])
            question = instance.get('question', instance['Question'])
            question = body + ' ' + question + '\nPlease respond only with the result.'
            instance['text'] = prompt.format(user_msg_1=question, model_answer_1_start='The answer is: ')
            instance['label'] = round(float(numbers[0].replace(",", "")), 2)
            instance['input_ids'] = tokenizer.encode(instance['text'], add_special_tokens=False)
            instance['num_tokens'] = len(instance['input_ids']) + max_new_tokens
            assert instance['num_tokens'] <= tokens_per_batch
            instances.append(instance)
    return instances
def get_model(hf_auth_token, model_name='meta-llama/Llama-2-7b-chat-hf', use_flash_attention_2=False):
    if len(sys.argv) > 1:
        hf_auth_token = sys.argv[1]
    if hf_auth_token == '':
        raise ValueError('Please provide your huggingface auth key in the script or as first command line argument')

    tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_auth_token)
    model = LlamaForCausalLM.from_pretrained(model_name,
                                             torch_dtype=torch.float16,
                                             device_map="auto",
                                             token=hf_auth_token,
                                             use_flash_attention_2=use_flash_attention_2
                                             )

    tokenizer.padding_side = 'left'
    model.config.pad_token_id = tokenizer.pad_token_id = 0
    model.config.bos_token_id = tokenizer.bos_token_id = 1
    model.config.eos_token_id = tokenizer.eos_token_id = 2
    return model, tokenizer
def get_prefill_input_ids(instances, model):
    assert len(instances) > 1
    # length of the longest common token prefix shared by all instances (e.g. the system prompt)
    num_prefill_tokens = None
    for i, value in enumerate(instances[0]['input_ids']):
        if any(ins['input_ids'][i] != value for ins in instances):
            num_prefill_tokens = i
            break
    assert num_prefill_tokens is not None
    prefill_input_ids = instances[0]['input_ids'][:num_prefill_tokens]

    # run the shared prefix through the model once and keep its key/value cache
    with torch.no_grad():
        input_ids = torch.tensor([prefill_input_ids])
        outputs = model(input_ids=input_ids.to(model.device))
        past_key_values = outputs.past_key_values
    return prefill_input_ids, past_key_values
def get_batches_equal_length(instances, tokens_per_batch):
    # only instances with exactly the same length share a batch, so no padding is needed
    instances = list(sorted(instances, key=lambda ins: -ins['num_tokens']))
    batches = list()
    batch = [instances[0]]
    num_tokens = instances[0]['num_tokens']
    for instance in instances[1:]:
        if instance['num_tokens'] == num_tokens and num_tokens * (len(batch) + 1) <= tokens_per_batch:
            batch.append(instance)
        else:
            batches.append(batch)
            batch = [instance]
            num_tokens = instance['num_tokens']
    batches.append(batch)
    return batches
def get_batches(instances, tokens_per_batch):
    # instances of different lengths can share a batch; shorter ones are left-padded in get_padded_inputs
    instances = list(sorted(instances, key=lambda ins: -ins['num_tokens']))
    batches = list()
    batch = [instances[0]]
    num_tokens = instances[0]['num_tokens']
    for instance in instances[1:]:
        if num_tokens * (len(batch) + 1) <= tokens_per_batch:
            batch.append(instance)
        else:
            batches.append(batch)
            batch = [instance]
            num_tokens = instance['num_tokens']
    batches.append(batch)
    return batches
def inference(model, tokenizer, batch, prefill=None, max_new_tokens=5):
    input_ids = [instance['input_ids'] for instance in batch]
    inputs = get_padded_inputs(input_ids, tokenizer, model, prefill)

    generated_sequences = [list() for _ in batch]
    is_finished = [False for _ in batch]
    # greedy decoding, one token per step
    for _ in range(max_new_tokens):
        with torch.no_grad():
            outputs = model(**inputs, use_cache=True)
        past_key_values = outputs.past_key_values
        next_token_logits = outputs.logits[:, -1, :]
        next_token_ids = torch.argmax(next_token_logits, dim=-1)

        for output_id, instance in enumerate(batch):
            if is_finished[output_id]:
                continue
            next_token_id = next_token_ids[output_id].item()
            if next_token_id == tokenizer.eos_token_id:
                is_finished[output_id] = True
                continue
            generated_sequences[output_id].append(next_token_id)

        inputs = update_generation_inputs(inputs=inputs,
                                          next_token_ids=next_token_ids.unsqueeze(-1),
                                          past_key_values=past_key_values)
        if all(is_finished):
            break

    for instance, generated_sequence in zip(batch, generated_sequences):
        output_text = tokenizer.decode(generated_sequence, skip_special_tokens=True)
        instance['response'] = output_text
def evaluate(instances):
    number_pattern = re.compile(r"-?(?<!\d)\d{1,10}(?:,\d{3})*(?:\.\d+)?%?(?!\d)")
    correct = 0
    for instance in instances:
        response = instance['response']
        pred = number_pattern.findall(response)
        if len(pred) > 0:
            pred = round(float(pred[0].replace(",", "")), 2)
            if pred == instance['label']:
                correct += 1
    accuracy = correct / len(instances) * 100
    return round(accuracy, 1)
if __name__ == '__main__':
    main()

raphael-sch commented Oct 13, 2023

With the right attention mask and position_ids you can combine left padding with prefilled (cached) prompt tokens in huggingface transformers. This speeds up batched inference, especially when every instance starts with the same system prompt.
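
To make the layout concrete, here is a minimal, model-free sketch of how the attention mask and position ids are built when a shared, already-cached prefix is combined with left padding. build_masks is a hypothetical helper for illustration only, not part of transformers; the actual logic lives in get_padded_inputs in the script above.

# Layout per instance: [shared prefix (cached)] [left padding] [instance-specific suffix]
def build_masks(suffix_lengths, num_prefill_tokens):
    max_length = max(suffix_lengths)
    attention_masks, position_ids = [], []
    for n in suffix_lengths:
        num_pad = max_length - n
        # prefix tokens are attended to, padding is masked out, suffix tokens are attended to
        attention_masks.append([1] * num_prefill_tokens + [0] * num_pad + [1] * n)
        # suffix positions continue right after the cached prefix; padding positions are dummies
        position_ids.append([1] * num_pad + list(range(num_prefill_tokens, num_prefill_tokens + n)))
    return attention_masks, position_ids

masks, positions = build_masks(suffix_lengths=[3, 5], num_prefill_tokens=4)
print(masks[0])      # [1, 1, 1, 1, 0, 0, 1, 1, 1]
print(positions[0])  # [1, 1, 4, 5, 6]

The attention mask covers the cached prefix plus the padded new tokens, while input_ids and position_ids only cover the padded new tokens, which is what the model expects when past_key_values is supplied.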

Run with

python run_padding_prefill.py your_huggingface_auth_key

number of instances: 2089

padding False, prefill False, flash_attention_2 False
duration in seconds: 56
accuracy: 42.5

padding False, prefill True, flash_attention_2 False
duration in seconds: 35
accuracy: 42.5

padding True, prefill False, flash_attention_2 False
duration in seconds: 49
accuracy: 42.5

padding True, prefill True, flash_attention_2 False
duration in seconds: 27
accuracy: 42.5

It also works with flash attention 2, although I don't see additional speed-ups there.

padding False, prefill False, flash_attention_2 True
duration in seconds: 57
accuracy: 42.5

padding False, prefill True, flash_attention_2 True
duration in seconds: 35
accuracy: 42.5

padding True, prefill False, flash_attention_2 True
duration in seconds: 48
accuracy: 42.5

padding True, prefill True, flash_attention_2 True
duration in seconds: 27
accuracy: 42.5
