@younesbelkada
Benchmark the Mistral 7B model
import argparse
import inspect
from pathlib import Path
from typing import List, Optional

import torch

from mistral.cache import RotatingBufferCache
from mistral.model import Transformer
from mistral.tokenizer import Tokenizer


def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-path",
        type=str,
        help="Model path",
        required=True,
    )
    parser.add_argument(
        "--max-new-tokens",
        type=int,
        default=512,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument(
        "--num-batches",
        type=int,
        default=1,
        help="Number of times to run the experiments",
    )
    return parser


@torch.inference_mode()
def generate(
    prompts: List[str],
    model: Transformer,
    tokenizer: Tokenizer,
    *,
    max_tokens: int,
    chunk_size: Optional[int] = None,
    temperature: float = 0.7,  # unused here: decoding below is greedy (argmax)
):
    model = model.eval()
    B, V = len(prompts), model.args.vocab_size
    device = torch.device("cuda:0")

    # Tokenize
    encoded_prompts = [tokenizer.encode(prompt, bos=True) for prompt in prompts]
    seqlens = [len(x) for x in encoded_prompts]

    # Cache
    cache_window = min(model.args.sliding_window, max(seqlens) + max_tokens)
    cache = RotatingBufferCache(
        model.args.n_layers,
        model.args.max_batch_size,
        cache_window,
        model.args.n_kv_heads,
        model.args.head_dim,
    )
    cache.to(device=model.device, dtype=model.dtype)
    cache.reset()

    last_token_prelogits = None

    # One chunk if size not specified
    max_prompt_len = max(seqlens)
    if chunk_size is None:
        chunk_size = max_prompt_len

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start_event.record()

    # Encode the prompt chunk by chunk
    for s in range(0, max_prompt_len, chunk_size):
        prompt_chunks = [p[s : s + chunk_size] for p in encoded_prompts]
        assert all(len(p) > 0 for p in prompt_chunks)
        prelogits = model.forward(
            torch.tensor(sum(prompt_chunks, []), device=model.device, dtype=torch.long),
            cache,
            seqlens=[len(p) for p in prompt_chunks],
        )
        # Keep only the prelogits of the last token of each sequence in the batch
        last_token_prelogits = prelogits.index_select(
            0,
            torch.tensor([len(p) for p in prompt_chunks], device=prelogits.device).cumsum(dim=0) - 1,
        )
        assert last_token_prelogits.shape == (B, V)

    # Decode greedily, one token at a time
    generated_tokens = []
    for _ in range(max_tokens):
        next_token = torch.argmax(last_token_prelogits, dim=-1)
        generated_tokens.append(next_token[:, None])
        last_token_prelogits = model.forward(next_token, cache, seqlens=[1] * len(prompts))
        assert last_token_prelogits.shape == (B, V)

    end_event.record()
    torch.cuda.synchronize()
    latency_s = start_event.elapsed_time(end_event) * 1e-3  # CUDA events report milliseconds
    max_memory = torch.cuda.max_memory_allocated(device)

    generated_words = []
    if generated_tokens:
        generated_tokens = torch.cat(generated_tokens, 1)
        for i, x in enumerate(encoded_prompts):
            generated_words.append(tokenizer.decode(x + generated_tokens[i].tolist()))

    return generated_words, (latency_s, max_memory)


def get_text():
    # This prompt tokenizes to roughly 11K tokens.
    # Modify this method to try out different scenarios.
    text = ["""Summarize the following news article in detail:\n""" * 1000]
    return text


def benchmark(model_path: str, max_tokens: int = 35, num_batches: int = 1):
    tokenizer = Tokenizer(str(Path(model_path) / "tokenizer.model"))
    text = get_text()
    transformer = Transformer.from_folder(Path(model_path), max_batch_size=len(text))

    # Check that we are effectively using memory-efficient attention from xformers
    assert "memory_efficient_attention" in inspect.getsource(
        transformer.layers[0].attention.forward
    ), "You did not load the optimized model"
    assert transformer.dtype == torch.float16

    # Warmup
    _ = generate(
        ["hi"],
        transformer,
        tokenizer,
        max_tokens=10,
    )

    total_latency = 0
    total_max_memory = 0

    # Retrieve generation stats
    for _ in range(num_batches):
        _, stats = generate(
            text,
            transformer,
            tokenizer,
            max_tokens=max_tokens,
        )
        latency_s, max_memory = stats
        total_latency += latency_s
        total_max_memory += max_memory

    mean_latency = total_latency / num_batches
    print(f"Mean latency: {mean_latency} s")
    print(f"{max_tokens / mean_latency} tokens / s")
    print(f"Mean max allocated memory: {total_max_memory / num_batches} bytes")


if __name__ == "__main__":
    parser = get_parser()
    args = parser.parse_args()
    benchmark(args.model_path, args.max_new_tokens, args.num_batches)
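

# Usage sketch (not part of the original gist): the benchmark can be launched through
# the CLI flags defined in get_parser(), or by calling benchmark() directly.
# "benchmark_mistral.py" and "./mistral-7B-v0.1" are hypothetical placeholders for the
# saved script name and the folder containing the downloaded reference weights.
#
#   python benchmark_mistral.py --model-path ./mistral-7B-v0.1 --max-new-tokens 512 --num-batches 3
#
# Equivalent direct call, e.g. from a notebook:
#
#   benchmark("./mistral-7B-v0.1", max_tokens=512, num_batches=3)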