Skip to content

Instantly share code, notes, and snippets.

@fastdaima
Forked from ArthurZucker/static_kv_cache.py
Created February 29, 2024 13:27
Show Gist options
  • Save fastdaima/c57e9d60dfadca157b4825b5c38fd953 to your computer and use it in GitHub Desktop.
Save fastdaima/c57e9d60dfadca157b4825b5c38fd953 to your computer and use it in GitHub Desktop.
Simple static KV-cache text-generation script
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache
import torch
from typing import Optional
# Target device for the model and every tensor this script allocates. The
# decode path also uses CUDA-specific pieces (torch.backends.cuda SDPA toggle,
# torch.compile "reduce-overhead" mode), so CUDA is expected here.
device: str = "cuda"
# Copied from the gpt-fast repo
def multinomial_sample_one_no_sync(probs_sort):
    """Draw one index per row of ``probs_sort`` without a CUDA synchronization.

    Uses the exponential-race trick: with independent ``E_i ~ Exp(1)``,
    ``argmax_i(p_i / E_i)`` selects index ``i`` with probability proportional to
    ``p_i``. This yields a categorical sample using only elementwise ops and an
    argmax.

    Args:
        probs_sort: non-negative probabilities; sampling is over the last dim.

    Returns:
        int32 tensor of sampled indices, with the last dim kept (size 1).
    """
    noise = torch.empty_like(probs_sort)
    noise.exponential_(1)
    winners = (probs_sort / noise).argmax(dim=-1, keepdim=True)
    return winners.int()
def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    """Turn logits into a probability distribution over the last dimension.

    The logits are divided by ``temperature``. The divisor is floored at 1e-5,
    so a zero or negative temperature behaves as 1e-5, which is near-greedy.

    When ``top_k`` is given, k is clamped to the size of the last dimension.
    Every logit strictly below the k-th largest value is then set to -inf
    before the softmax. Values tied with the k-th largest are kept, so more
    than k entries may end up with non-zero probability.
    """
    scaled = logits / max(temperature, 1e-5)
    if top_k is None:
        return torch.nn.functional.softmax(scaled, dim=-1)
    k = min(top_k, scaled.size(-1))
    # topk is sorted descending, so the last column is the k-th largest value.
    kth_largest = torch.topk(scaled, k).values[..., -1:]
    scaled = scaled.masked_fill(scaled < kth_largest, -float("Inf"))
    return torch.nn.functional.softmax(scaled, dim=-1)
def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    """Sample the next token from the logits of the final sequence position.

    Only ``logits[:, -1]`` is used. It is converted to probabilities by
    ``logits_to_probs`` and sampled with ``multinomial_sample_one_no_sync``.
    ``logits`` is assumed to be (batch, seq, vocab); TODO confirm.

    Returns:
        ``(next_token_ids, probs)``. ``next_token_ids`` is the int32 output of
        ``multinomial_sample_one_no_sync``; ``probs`` is the distribution it was
        drawn from.
    """
    last_step_logits = logits[:, -1]
    probs = logits_to_probs(last_step_logits, temperature, top_k)
    return multinomial_sample_one_no_sync(probs), probs
def decode_one_tokens(model, cur_token, cache_position, temperature: float = 0.6, top_k: Optional[int] = 5):
    """Run a single cached decoding step and sample the next token.

    This function is wrapped with ``torch.compile`` by the script, so it should
    stay free of graph breaks. The sampling parameters are plain Python
    constants, which the compiler specializes on.

    Args:
        model: causal LM, called as
            ``model(cur_token, cache_position=..., return_dict=False, use_cache=True)``.
            The first element of the returned tuple is taken as the logits.
        cur_token: token ids fed for this step; presumably (batch, 1), TODO confirm.
        cache_position: cache slot(s) that ``cur_token`` occupies.
        temperature: sampling temperature forwarded to ``sample``. The default
            0.6 matches the previously hard-coded value.
        top_k: top-k cutoff forwarded to ``sample``. The default 5 matches the
            previously hard-coded value; ``None`` disables the cutoff.

    Returns:
        The sampled next-token ids (first element of ``sample``'s result).
    """
    logits = model(cur_token, cache_position=cache_position, return_dict=False, use_cache=True)[0]
    new_token = sample(logits, temperature=temperature, top_k=top_k)[0]
    return new_token
# Load the checkpoint in bfloat16, move it to `device`, and switch to eval mode.
checkpoint = "meta-llama/Llama-2-7b-chat-hf"
model = (
    AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16)
    .to(device)
    .eval()
)
# Only the single-token decode step is compiled; the prefill runs eagerly.
# "reduce-overhead" mode uses CUDA graphs to cut per-step launch overhead.
# fullgraph=True turns any graph break into an error instead of a silent
# fallback to eager execution.
decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# Tokenize the prompt and allocate everything the decode loop needs.
prompt = "My favourite condiment is"
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
batch_size, sequence_length = input_ids.shape
max_cache_length = 2048
max_new_tokens = 100
# The static cache has a fixed number of slots, and the prompt plus every
# generated token must fit in it. Fail early with a readable message rather
# than part-way through generation inside the cache update.
if sequence_length + max_new_tokens > max_cache_length:
    raise ValueError(
        f"prompt ({sequence_length} tokens) + max_new_tokens ({max_new_tokens}) "
        f"exceeds max_cache_length ({max_cache_length})"
    )
# NOTE(review): `_setup_cache` is a private transformers API from the release
# this script targets; confirm it still exists in the installed version.
model._setup_cache(StaticCache, batch_size, max_cache_len=max_cache_length)
# Prompt followed by generated tokens. int32 matches the dtype that
# multinomial_sample_one_no_sync returns for sampled ids.
generated_ids = torch.zeros((batch_size, max_new_tokens + sequence_length), dtype=torch.int, device=device)
generated_ids[:, :sequence_length] = input_ids
# Slot of the first generated token, i.e. the first token fed to the decode step.
cache_position = torch.tensor([sequence_length], device=device)
with torch.no_grad():
    # Prefill: run the eager model over the whole prompt (cache slots
    # 0..sequence_length-1) and sample the first new token from the last position.
    logits = model(input_ids, cache_position=torch.arange(sequence_length, device=device))[0]
    input_id = sample(logits, temperature=0.6, top_k=5)[0]
    generated_ids[:, sequence_length] = input_id[:, 0]
    # Decode: each step feeds the newest token at slot `cache_position` and
    # stores the token it predicts one slot later. Prefill already produced one
    # token, so max_new_tokens - 1 steps fill generated_ids up to its last column.
    for _ in range(1, max_new_tokens):
        # Math-only SDPA, as in the original script; presumably for
        # compatibility with the compiled decode step. TODO confirm.
        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
            # clone(): presumably so a CUDA-graph output buffer from the previous
            # replay is not fed straight back in as this step's input.
            input_id = decode_one_tokens(model, input_id.clone(), cache_position)
        generated_ids.index_copy_(1, cache_position + 1, input_id)
        cache_position += 1
print(tokenizer.batch_decode(generated_ids.long()))
["<s> My favourite condiment is ketchup. I know, I know, it's a bit cliche, but there's just something about the sweet and tangy flavour that I can't get enough of. I put it on everything from fries to scrambled eggs to grilled meats. And let's be real, it's the perfect accompaniment to a good old-fashioned burger and fries.\n\nBut ketchup isn't just delicious"]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment