simple static kv cache script
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache
import torch
from typing import Optional

device = "cuda"

# Copied from the gpt-fast repo
def multinomial_sample_one_no_sync(probs_sort):  # Does multinomial sampling without a cuda synchronization
    q = torch.empty_like(probs_sort).exponential_(1)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)

def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    logits = logits / max(temperature, 1e-5)
    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        pivot = v.select(-1, -1).unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs

def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    probs = logits_to_probs(logits[:, -1], temperature, top_k)
    idx_next = multinomial_sample_one_no_sync(probs)
    return idx_next, probs

def decode_one_tokens(model, cur_token, cache_position):
    logits = model(cur_token, cache_position=cache_position, return_dict=False, use_cache=True)[0]
    new_token = sample(logits, temperature=0.6, top_k=5)[0]
    return new_token

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.bfloat16)
model = model.to(device).eval()
# Compile only the single-token decode step; the prefill runs eagerly below.
decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
prompt = "My favourite condiment is"
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
batch_size, sequence_length = input_ids.shape

max_cache_length = 2048
max_new_tokens = 100
# Pre-allocate the static KV cache and the output buffer.
model._setup_cache(StaticCache, batch_size, max_cache_len=max_cache_length)
generated_ids = torch.zeros((batch_size, max_new_tokens + sequence_length), dtype=torch.int, device=device)
generated_ids[:, :sequence_length] = input_ids
cache_position = torch.tensor([sequence_length], device=device)

with torch.no_grad():
    for i in range(max_new_tokens):
        if i == 0:  # prefill uses the vanilla (non-compiled) model
            logits = model(input_ids, cache_position=torch.arange(sequence_length, device=device))[0]
            input_id = sample(logits, temperature=0.6, top_k=5)[0]
            generated_ids[:, sequence_length] = input_id[:, 0]
        else:  # every later token goes through the compiled decode step
            with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
                input_id = decode_one_tokens(model, input_id.clone(), cache_position)
                generated_ids.index_copy_(1, cache_position, input_id)
        cache_position += 1

print(tokenizer.batch_decode(generated_ids.long()))
["<s> My favourite condiment is ketchup. I know, I know, it's a bit cliche, but there's just something about the sweet and tangy flavour that I can't get enough of. I put it on everything from fries to scrambled eggs to grilled meats. And let's be real, it's the perfect accompaniment to a good old-fashioned burger and fries.\n\nBut ketchup isn't just delicious"] |
ArthurZucker (Author) commented on Apr 25, 2024 via email
https://gist.github.com/ArthurZucker/ae0a86ef8f841c0ef69aaa52ccbc0b03
Does this work?
… On 25 Apr 2024, at 16:07, Billy Cao (@aliencaocao) wrote:
Hi @ArthurZucker, the link is broken btw
Yes, thanks
Thanks for providing this sample! I am building a customized MLP block with Triton 2.2, and it looks like my model cannot be processed by torch.compile. It works well with CUDA Graphs without the cache, though. Could you provide sample code that builds a CUDA graph with a static cache? The error I get is:
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
LoweringException: AttributeError: 'View' object has no attribute 'layout'
target: triton_kernel_wrapper_functional
kwargs: {'kernel_idx': 0, 'grid': [(688, 1, 1), (688, 1, 1), (688, 1, 1), (688, 1, 1), (688, 1, 1), (688, 1, 1), (688, 1, 1), (688, 1, 1), (688, 1, 1), (688, 1, 1), (344, 1, 1), (344, 1, 1), (344, 1, 1), (344, 1, 1), (344, 1, 1), (344, 1, 1), (344, 1, 1), (344, 1, 1), (344, 1, 1), (344, 1, 1), (172, 1, 1), (172, 1, 1), (172, 1, 1), (172, 1, 1), (172, 1, 1), (172, 1, 1), (172, 1, 1), (172, 1, 1), (172, 1, 1), (172, 1, 1)], 'kwargs': {'a_ptr': TensorBox(
View(
SliceView(
StorageBox(
ComputedBuffer(name='buf19', layout=FixedLayout('cuda', torch.float16, size=[1, 1, 4096], stride=[4096, 4096, 1]), data=Pointwise(
'cuda',
torch.float16,
def inner_fn(index):
....
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
Any suggestion on fixing this error would also be helpful!
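For reference, a minimal sketch of the manual-capture route: replaying one decode step with torch.cuda.CUDAGraph instead of torch.compile. It assumes the same model, StaticCache setup, and sample() helper as the script above, is meant to run after the prefill step, and the static_* buffer names are placeholders for illustration; it has not been tested with a custom Triton MLP.

# Static buffers the captured graph reads from and writes to.
static_input_id = torch.zeros((batch_size, 1), dtype=torch.long, device=device)
static_cache_position = torch.tensor([sequence_length], device=device)

# Warm up on a side stream so cuBLAS/cuDNN workspaces are allocated before
# capture (allocations triggered mid-capture would fail).
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s), torch.no_grad():
    for _ in range(3):
        model(static_input_id, cache_position=static_cache_position, return_dict=False, use_cache=True)
torch.cuda.current_stream().wait_stream(s)

# Capture a single decode forward pass into a CUDA graph.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph), torch.no_grad():
    static_logits = model(static_input_id, cache_position=static_cache_position, return_dict=False, use_cache=True)[0]

def graph_decode_one_token(cur_token, cache_position):
    # Copy fresh inputs into the static buffers, replay the captured kernels,
    # then sample from the logits the graph wrote into static_logits.
    static_input_id.copy_(cur_token)
    static_cache_position.copy_(cache_position)
    graph.replay()
    return sample(static_logits, temperature=0.6, top_k=5)[0]

In the generation loop above, graph_decode_one_token(input_id, cache_position) would then take the place of the compiled decode_one_tokens call. The warm-up and capture write throwaway values into the cache slot at sequence_length, but the first real decode overwrites them, since StaticCache indexes with the cache_position tensor whose contents are read at replay time.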