Simple static KV cache script: sampled generation with Llama-2-7b-chat using transformers' StaticCache and a torch.compile'd single-token decode step.
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache
import torch
from typing import Optional

device = "cuda"

# Copied from the gpt-fast repo
def multinomial_sample_one_no_sync(probs_sort):  # Does multinomial sampling without a cuda synchronization
    q = torch.empty_like(probs_sort).exponential_(1)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)

def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    logits = logits / max(temperature, 1e-5)
    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        pivot = v.select(-1, -1).unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs

def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    probs = logits_to_probs(logits[:, -1], temperature, top_k)
    idx_next = multinomial_sample_one_no_sync(probs)
    return idx_next, probs
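# Why the trick above avoids a sync: torch.multinomial can force a host-device
# synchronization before the next kernel is queued, stalling the decode loop
# (the reason gpt-fast added this helper). Dividing the probabilities by i.i.d.
# Exponential(1) noise and taking an argmax draws an equivalent categorical
# sample (the Gumbel-max trick in another form: argmax p_i / E_i ==
# argmax log p_i - log E_i, and -log E_i ~ Gumbel(0, 1)), and the argmax
# stays entirely on the GPU.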
def decode_one_tokens(model, cur_token, cache_position):
    logits = model(cur_token, cache_position=cache_position, return_dict=False, use_cache=True)[0]
    new_token = sample(logits, temperature=0.6, top_k=5)[0]
    return new_token

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.bfloat16)
model = model.to(device).eval()
# Compile only the single-token decode step: with a static cache every shape is
# fixed, so mode="reduce-overhead" (CUDA graphs) and fullgraph=True are safe.
decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
prompt = "My favourite condiment is"
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
batch_size, sequence_length = input_ids.shape

max_cache_length = 2048
max_new_tokens = 100
# _setup_cache is a private transformers hook (v4.38-era): it pre-allocates a
# fixed-shape StaticCache on every attention layer, so decoding never resizes the cache.
model._setup_cache(StaticCache, batch_size, max_cache_len=max_cache_length)
generated_ids = torch.zeros((batch_size, max_new_tokens + sequence_length), dtype=torch.int, device=device)
generated_ids[:, :sequence_length] = input_ids
cache_position = torch.tensor([sequence_length], device=device)
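# Bookkeeping for the static cache: cache_position tells each forward pass which
# slot(s) of the pre-allocated cache the current tokens occupy. The prefill below
# passes arange(sequence_length) for the whole prompt; every decode step then
# passes a single index that advances by one per generated token.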
with torch.no_grad():
    for i in range(max_new_tokens):
        if i == 0:  # prefill uses the vanilla (uncompiled) model
            logits = model(input_ids, cache_position=torch.arange(sequence_length, device=device))[0]
            input_id = sample(logits, temperature=0.6, top_k=5)[0]
            generated_ids[:, sequence_length] = input_id[:, 0]
        else:  # decode one token at a time through the compiled step
            # Force the math SDP backend; the flash/mem-efficient kernels did not
            # capture cleanly under CUDA graphs at the time this was written.
            with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
                # .clone() so the CUDA graph's static output buffer is not reused as input
                input_id = decode_one_tokens(model, input_id.clone(), cache_position)
                generated_ids.index_copy_(1, cache_position, input_id)
        cache_position += 1

print(tokenizer.batch_decode(generated_ids.long()))
["<s> My favourite condiment is ketchup. I know, I know, it's a bit cliche, but there's just something about the sweet and tangy flavour that I can't get enough of. I put it on everything from fries to scrambled eggs to grilled meats. And let's be real, it's the perfect accompaniment to a good old-fashioned burger and fries.\n\nBut ketchup isn't just delicious"] |